# region import

import matplotlib.pyplot as mpl


# endregion
# --------------------------------------------------------------------------------------------------------------------
# region 3D piecewise linear approx. (meshing)

def _equidistant_breakpoints(count, low, high):
    """Return ``count`` equally spaced breakpoints from ``low`` to ``high``.

    The first and last entries are assigned ``low`` and ``high`` directly
    (not computed), so the grid ends exactly on the given limits. Interior
    points are ``step * k + low`` with ``step = (high - low) / (count - 1)``.
    A ``count`` of 1 yields ``[low]``; a ``count`` below 1 yields ``[]``.
    """
    if count < 1:
        return []
    if count == 1:
        return [low]
    step = (high - low) / (count - 1)
    return [low] + [step * k + low for k in range(1, count - 1)] + [high]


def _plot_lin_approx(x_bp, y_bp, x_bp_count, y_bp_count, z_real_value):
    """Draw the sampled surface and its staircase approximation on new 3D axes.

    ``x_bp_count``, ``y_bp_count`` and ``z_real_value`` are the flattened grid
    in x-major order: all y breakpoints for ``x_bp[0]``, then for ``x_bp[1]``,
    and so on. The axes are added to the current matplotlib figure and returned;
    nothing is shown, so the caller decides when to call ``mpl.show()``.
    """
    bp_num_x = len(x_bp)
    bp_num_y = len(y_bp)
    # Number of grid points excluding the last x row.
    n_head = len(x_bp_count) - bp_num_y

    # Shifted copy of the grid: the z value sampled at (x_i, y_j) is also
    # placed at (x_{i+1}, y_j), i.e. at the right-hand corner of each step.
    x_approx = x_bp_count + x_bp_count[bp_num_y:]
    y_approx = y_bp_count + y_bp_count[:n_head]
    z_approx = z_real_value + z_real_value[:n_head]

    ax = mpl.axes(projection='3d')

    # One staircase per y breakpoint; where='post' holds z(x_i) until x_{i+1}.
    for j, y_val in enumerate(y_bp):
        z_row = [z_real_value[i * bp_num_y + j] for i in range(bp_num_x)]
        ax.step(x_bp, z_row, zs=y_val, zdir='y', where='post', c='black')

    # Profile over y of each x row (except the last), drawn at both ends of its step.
    for i in range(bp_num_x - 1):
        z_col = z_real_value[i * bp_num_y:(i + 1) * bp_num_y]
        ax.plot(y_bp, z_col, zs=x_bp[i], zdir='x', c='black')
        ax.plot(y_bp, z_col, zs=x_bp[i + 1], zdir='x', c='black')

    ax.scatter3D(x_approx, y_approx, z_approx, s=2, c='gold')
    ax.plot_trisurf(x_bp_count, y_bp_count, z_real_value, color='azure')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    return ax


def ddd_lin_approx(ddd_nl_func, bp_num_x, bp_num_y, x_limit_up, y_limit_up,
                   x_limit_low, y_limit_low, *, plot=True):
    """Sample a function on an equidistant x/y breakpoint grid and return Pyomo-ready dicts.

    Args:
        ddd_nl_func: Callback invoked as ``ddd_nl_func(x_bp, y_bp, i, j)``.
            ``x_bp``/``y_bp`` are the full breakpoint lists, and ``i`` (x) and
            ``j`` (y) are 1-based indices. It is called in x-major order
            (all j for i=1, then i=2, ...). Its return value is stored as the z
            sample of that grid point.
        bp_num_x: Number of breakpoints on the x-axis.
        bp_num_y: Number of breakpoints on the y-axis.
        x_limit_up: Last x breakpoint.
        y_limit_up: Last y breakpoint.
        x_limit_low: First x breakpoint.
        y_limit_low: First y breakpoint.
        plot: If True (default, the original behaviour), draw the sampled
            surface and its staircase approximation on a new 3D axes of the
            current matplotlib figure. The figure is not shown; call
            ``mpl.show()`` afterwards. Pass False to compute the data only.
            Plotting needs enough grid points for ``plot_trisurf`` to
            triangulate.

    Returns:
        Tuple ``(x, y, z)`` of dicts keyed from 1:

        - ``x``: maps ``1..bp_num_x`` to the x breakpoints.
        - ``y``: maps ``1..bp_num_y`` to the y breakpoints.
        - ``z``: maps ``k = (i - 1) * bp_num_y + j`` for ``i`` in ``1..bp_num_x - 1``
          to the sample at ``(x_i, y_j)``, i.e. ``(bp_num_x - 1) * bp_num_y``
          entries. The samples of the last x breakpoint are not included.
          Presumably this is because a step over ``[x_i, x_{i+1}]`` only needs
          the value at ``x_i``; confirm against the Pyomo model.
    """
    x_bp = _equidistant_breakpoints(bp_num_x, x_limit_low, x_limit_up)
    y_bp = _equidistant_breakpoints(bp_num_y, y_limit_low, y_limit_up)

    # Flattened grid in x-major order (outer loop x, inner loop y).
    x_bp_count = []
    y_bp_count = []
    z_real_value = []
    for i, x_val in enumerate(x_bp, start=1):
        for j, y_val in enumerate(y_bp, start=1):
            x_bp_count.append(x_val)
            y_bp_count.append(y_val)
            z_real_value.append(ddd_nl_func(x_bp, y_bp, i, j))

    if plot:
        _plot_lin_approx(x_bp, y_bp, x_bp_count, y_bp_count, z_real_value)

    # Pyomo-style 1-based index dictionaries; z drops the last x row.
    x = dict(enumerate(x_bp, start=1))
    y = dict(enumerate(y_bp, start=1))
    z = dict(enumerate(z_real_value[:len(z_real_value) - bp_num_y], start=1))
    return x, y, z

# endregion
# --------------------------------------------------------------------------------------------------------------------
