from darts.models.reservoirs.struct_reservoir import StructReservoir
from darts.models.physics.black_oil import BlackOil
from darts.models.darts_model import DartsModel
from darts.engines import value_vector, sim_params
import numpy as np

class Model(DartsModel):
    """Black-oil model on a single vertical column of 1 x 1 x nb cells.

    Builds a ``StructReservoir``, the ``BlackOil`` physics (configured from
    ``'physics.in'``) and the time-stepping / Newton / linear-solver
    parameters. It also provides initial conditions, well controls, operator
    assignment and a VTK export that adds oil/water saturations.
    """

    def __init__(self, n_points):
        """Set up reservoir, physics and solver parameters.

        :param n_points: number of points passed to the ``BlackOil`` physics
            constructor (stored as ``self.n_points``).
        """
        # call base class constructor
        super().__init__()

        # measure time spent on reading/initialization
        self.timer.node["initialization"].start()

        # create reservoir: one column of nb cells whose depths start at 1000
        # and increase by 10 per cell (1000, 1010, ..., 1090 for nb = 10)
        nb = 10
        self.depth = 1000.0 + 10.0 * np.arange(nb)

        self.reservoir = StructReservoir(self.timer, nx=1, ny=1, nz=nb, dx=10, dy=10,
                                         dz=10, permx=10, permy=10, permz=100,
                                         poro=0.2, depth=self.depth)

        # physics definition
        self.log_based = False
        self.obl_min_comp = 1e-13
        self.n_points = n_points
        self.min_p = 200
        self.max_p = 400
        self.min_z = self.obl_min_comp
        self.max_z = 1 - self.obl_min_comp
        self.inj_stream = value_vector([1e-10, 1e-10])  # composition injection stream

        if self.log_based:
            # Move composition bounds and the injection stream to log space.
            # Re-wrap the stream as a value_vector so it has the same type as
            # in the linear branch (np.log alone would return an ndarray).
            min_z = self.min_z
            self.min_z = np.log(min_z)
            self.max_z = np.log(1 - min_z)
            self.inj_stream = value_vector(np.log(np.array(self.inj_stream)).tolist())

        self.physics = BlackOil(self.timer, 'physics.in', self.n_points, self.min_p, self.max_p, self.min_z, self.max_z)
        self.params.first_ts = 1e-4
        self.params.mult_ts = 2
        self.params.max_ts = 10

        # Newton tolerance is relatively high because of L2-norm for residual and well segments
        self.params.tolerance_newton = 1e-3
        self.params.tolerance_linear = 1e-3
        self.params.max_i_newton = 20
        self.params.max_i_linear = 10
        self.params.newton_type = sim_params.newton_local_chop
        self.params.newton_params[0] = 0.2

        self.runtime = 1000
        # self.physics.engine.silent_mode = 0
        self.timer.node["initialization"].stop()

    def set_initial_conditions(self):
        """Assign per-cell pressure and zg/zo compositions to the reservoir mesh.

        The arrays hold one value per cell of the 10-cell column. The first
        five cells carry near-zero zg/zo; the last five carry zg ~ 0.105 and
        zo ~ 0.894.
        """
        zg = np.array([7.444e-012,7.472e-012,7.501e-012,7.505e-012,7.505e-012,0.1053,0.1057,0.1062,0.1066,0.1071])
        zo = np.array([6.482e-011,6.475e-011,6.469e-011,6.432e-011,6.402e-011,0.8947,0.8943,0.8938,0.8934,0.8929])
        pres = np.array([260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, 268.0, 269.0])
        self.physics.set_nonuniform_initial_conditions(self.reservoir.mesh, pres, zg, zo)

    def set_boundary_conditions(self):
        """Attach well controls.

        The first well becomes a BHP injector (350) using ``self.inj_stream``.
        Every other well becomes a BHP producer (200). The commented lines are
        alternative rate/constraint controls kept for experimentation.
        """
        for i, w in enumerate(self.reservoir.wells):
            if i == 0:
                w.control = self.physics.new_bhp_inj(350, self.inj_stream)
                # w.control = self.physics.new_rate_gas_inj(3000)
                # w.constraint = self.physics.new_bhp_inj(400)
            else:
                w.control = self.physics.new_bhp_prod(200)
                # w.control = self.physics.new_rate_oil_prod(3000)
                # w.constraint = self.physics.new_bhp_prod(70)

    def set_op_list(self):
        """Assign operator indices and register the operator interpolators.

        Blocks beyond the reservoir blocks (index >= n_res_blocks) get operator
        index 1, i.e. ``acc_flux_w_itor``. The reservoir blocks keep whatever
        ``mesh.op_num`` already holds.
        """
        # zero-copy view: writes below go straight into the mesh's op_num
        self.op_num = np.array(self.reservoir.mesh.op_num, copy=False)
        n_res = self.reservoir.mesh.n_res_blocks
        self.op_num[n_res:] = 1
        self.op_list = [self.physics.acc_flux_itor, self.physics.acc_flux_w_itor]

    def export_sat_vtk(self, file_name='data', local_cell_data=None, global_cell_data=None,
                       vars_data_dtype=np.float32, export_grid_data=True):
        """Export primary variables plus oil/water saturations to VTK.

        :param file_name: base name passed to ``reservoir.export_vtk``.
        :param local_cell_data: dict of per-cell arrays to export. It is
            updated in place with every primary variable and with 'So'/'Sw'.
            Defaults to a fresh empty dict on each call.
        :param global_cell_data: dict forwarded unchanged to
            ``reservoir.export_vtk``. Defaults to a fresh empty dict.
        :param vars_data_dtype: dtype the exported arrays are cast to.
        :param export_grid_data: forwarded to ``reservoir.export_vtk``.
        """
        # Fresh dicts per call: a shared mutable default would accumulate
        # entries across exports because local_cell_data is mutated below.
        if local_cell_data is None:
            local_cell_data = {}
        if global_cell_data is None:
            global_cell_data = {}

        # get current engine time
        t = self.physics.engine.t
        nb = self.reservoir.mesh.n_res_blocks
        nv = self.physics.n_vars
        X = np.array(self.physics.engine.X, copy=False)

        # Output buffers for the saturation evaluators, wrapped by numpy
        # (zero-copy) so the evaluated values can be read back.
        oil_sat = value_vector([0] * self.reservoir.nb)
        water_sat = value_vector([0] * self.reservoir.nb)
        np_oil_sat = np.array(oil_sat, copy=False)
        np_water_sat = np.array(water_sat, copy=False)

        self.physics.bo_oil_sat_ev.evaluate(self.physics.engine.X, self.reservoir.nb, oil_sat)
        self.physics.bo_water_sat_ev.evaluate(self.physics.engine.X, self.reservoir.nb, water_sat)

        # X stores nv interleaved unknowns per block; take variable v of the
        # first nb (reservoir) blocks
        for v in range(nv):
            local_cell_data[self.physics.vars[v]] = X[v:nb * nv:nv].astype(vars_data_dtype)

        local_cell_data['So'] = np_oil_sat.astype(vars_data_dtype)
        local_cell_data['Sw'] = np_water_sat.astype(vars_data_dtype)

        self.reservoir.export_vtk(file_name, t, local_cell_data, global_cell_data, export_grid_data)