from darts.models.reservoirs.struct_reservoir import StructReservoir
from darts.models.physics.black_oil import BlackOil
from darts.models.darts_model import DartsModel
from darts.engines import value_vector
import numpy as np

class Model(DartsModel):
    """Three-layer black-oil DARTS model with one injector and one producer.

    The reservoir is a 10 x 10 x 3 structured grid in which every layer has
    its own permeability, thickness and depth. Well ``I1`` is perforated at
    cell (1, 1, 1) and well ``P1`` at cell (10, 10, 3).
    """

    def __init__(self, n_points):
        """Build the reservoir, wells, physics and solver parameters.

        :param n_points: number of points handed to ``BlackOil`` for the
            OBL parametrization (presumably per axis -- confirm in BlackOil).
        """
        # call base class constructor
        super().__init__()

        # measure time spent on reading/initialization
        self.timer.node["initialization"].start()

        # Grid dimensions. Per-cell arrays are built one layer (nx * ny cells)
        # at a time, i.e. this assumes the cell ordering is layer-major
        # (k slowest) -- the same layout the original 100-cell slices used.
        nx, ny, nz = 10, 10, 3
        cells_per_layer = nx * ny

        def layered(values):
            # Expand one float value per layer into a per-cell array of length nx*ny*nz.
            return np.repeat(np.asarray(values, dtype=float), cells_per_layer)

        # create reservoir
        self.kx = layered([500, 50, 200])
        # NOTE: ky is the *same array object* as kx (isotropic horizontal
        # permeability); in-place changes to one are seen through the other.
        self.ky = self.kx
        self.kz = layered([80, 42, 20])
        self.poro = 0.3
        self.dz = layered([6, 9, 15])
        self.depth = layered([2540, 2546, 2555])

        self.reservoir = StructReservoir(self.timer, nx=nx, ny=ny, nz=nz, dx=300, dy=300,
                                         dz=self.dz, permx=self.kx, permy=self.ky, permz=self.kz,
                                         poro=self.poro, depth=self.depth)

        # one injector in the first corner cell, one producer in the opposite corner
        self.reservoir.add_well("I1")
        self.reservoir.add_perforation(well=self.reservoir.wells[-1], i=1, j=1, k=1, multi_segment=False)
        self.reservoir.add_well("P1")
        self.reservoir.add_perforation(self.reservoir.wells[-1], 10, 10, 3, multi_segment=False)

        # physics definition
        self.log_based = False
        self.obl_min_comp = 1e-16
        self.n_points = n_points
        self.min_p = 50
        self.max_p = 500
        self.min_z = self.obl_min_comp
        self.max_z = 1 - self.obl_min_comp
        self.inj_stream = value_vector([1 - 2e-8, 1e-8])  # composition injection stream

        if self.log_based:
            # Transform the composition bounds and the injection stream to log space.
            min_z = self.min_z
            self.min_z = np.log(min_z)
            self.max_z = np.log(1 - min_z)
            # Re-wrap as value_vector so new_bhp_inj receives the same type as
            # in the linear-space branch (np.log alone returns an ndarray).
            self.inj_stream = value_vector(np.log(np.asarray(self.inj_stream)).tolist())

        self.physics = BlackOil(self.timer, 'physics.in', self.n_points,
                                self.min_p, self.max_p, self.min_z, self.max_z)

        # time-step control
        self.params.first_ts = 0.01
        self.params.mult_ts = 2
        self.params.max_ts = 10

        # Newton tolerance is relatively high because of L2-norm for residual and well segments
        self.params.tolerance_newton = 1e-3
        self.params.tolerance_linear = 1e-3
        self.params.max_i_newton = 20
        self.params.max_i_linear = 10
        # self.params.newton_type = 1
        # self.params.newton_params[0] = 0.2

        self.runtime = 3650
        # self.physics.engine.silent_mode = 0
        self.timer.node["initialization"].stop()

    def set_initial_conditions(self):
        """Set a uniform initial pressure (330) and composition on the mesh.

        When ``log_based`` is enabled the composition is passed in log space,
        matching the transformed OBL bounds set in ``__init__``.
        """
        initial_z = [0.001225901537, 0.7711341309]
        comp = np.log(initial_z) if self.log_based else np.array(initial_z)

        self.physics.set_uniform_initial_conditions(self.reservoir.mesh, uniform_pressure=330, uniform_composition=comp)

    def set_boundary_conditions(self):
        """Assign well controls.

        The first well (I1) is a BHP-controlled injector (BHP 400) injecting
        ``self.inj_stream``. Every other well is an oil-rate producer
        (rate 3000) with a minimum-BHP constraint of 70.
        """
        for i, w in enumerate(self.reservoir.wells):
            if i == 0:
                w.control = self.physics.new_bhp_inj(400, self.inj_stream)
            else:
                w.control = self.physics.new_rate_oil_prod(3000)
                w.constraint = self.physics.new_bhp_prod(70)

    def set_op_list(self):
        """Map blocks to operator interpolators.

        Blocks past ``n_res_blocks`` (the well blocks) get op index 1, i.e.
        ``acc_flux_w_itor``. Reservoir blocks keep whatever index the mesh
        already holds (presumably 0 -> ``acc_flux_itor``; verify in mesh init).
        """
        # numpy view over the engine's op_num storage, so assignment writes through
        self.op_num = np.array(self.reservoir.mesh.op_num, copy=False)
        n_res = self.reservoir.mesh.n_res_blocks
        self.op_num[n_res:] = 1
        self.op_list = [self.physics.acc_flux_itor, self.physics.acc_flux_w_itor]

    def export_sat_vtk(self, file_name='data', local_cell_data=None, global_cell_data=None,
                       vars_data_dtype=np.float32, export_grid_data=True):
        """Export primary variables plus oil/water saturations to VTK.

        :param file_name: base name handed to ``reservoir.export_vtk``.
        :param local_cell_data: optional dict of extra per-cell arrays. It is
            updated in place with one entry per primary variable and with
            ``'So'`` / ``'Sw'``. A fresh dict is used when omitted.
        :param global_cell_data: optional dict forwarded unchanged to
            ``export_vtk``. A fresh dict is used when omitted.
        :param vars_data_dtype: dtype the exported arrays are cast to.
        :param export_grid_data: forwarded to ``export_vtk``.
        """
        # Fresh dicts per call: a shared mutable default would keep (and
        # re-export) arrays written by previous calls.
        if local_cell_data is None:
            local_cell_data = {}
        if global_cell_data is None:
            global_cell_data = {}

        # get current engine time
        t = self.physics.engine.t
        nb = self.reservoir.mesh.n_res_blocks
        nv = self.physics.n_vars
        X = np.array(self.physics.engine.X, copy=False)

        # DARTS vectors that the saturation evaluators write into; the numpy
        # wrappers view the same memory, so the results are read back directly.
        oil_sat = value_vector([0] * self.reservoir.nb)
        water_sat = value_vector([0] * self.reservoir.nb)
        np_oil_sat = np.array(oil_sat, copy=False)
        np_water_sat = np.array(water_sat, copy=False)

        self.physics.bo_oil_sat_ev.evaluate(self.physics.engine.X, self.reservoir.nb, oil_sat)
        self.physics.bo_water_sat_ev.evaluate(self.physics.engine.X, self.reservoir.nb, water_sat)

        # X is interleaved per block (variable v of block b sits at b * nv + v),
        # so a stride-nv slice extracts one variable over the reservoir blocks.
        for v in range(nv):
            local_cell_data[self.physics.vars[v]] = X[v:nb * nv:nv].astype(vars_data_dtype)

        local_cell_data['So'] = np_oil_sat.astype(vars_data_dtype)
        local_cell_data['Sw'] = np_water_sat.astype(vars_data_dtype)

        self.reservoir.export_vtk(file_name, t, local_cell_data, global_cell_data, export_grid_data)