from darts.engines import value_vector, sim_params
from darts.models.reservoirs.struct_reservoir import StructReservoir
from darts.models.physics.dead_oil import DeadOil
from darts.models.darts_model import DartsModel
from darts.tools.keyword_file_tools import load_single_keyword
import numpy as np


class Model(DartsModel):
    """Dead-oil model on a single-layer structured grid built from SPE10 property files.

    The grid is 60 x 220 x 1 blocks. One injector ("I1") is perforated at block
    (30, 110, 1). Four producers ("P1".."P4") are perforated at the four lateral
    corner blocks.
    """

    def __init__(self, n_points):
        """Load grid properties, create reservoir and wells, physics and solver parameters.

        :param n_points: forwarded to ``DeadOil`` as ``n_points``
        """
        # call base class constructor
        super().__init__()

        # measure time spent on reading/initialization
        self.timer.node["initialization"].start()

        # Reservoir from SPE10 data: a single layer of 60 x 220 x 1 blocks with
        # 20 ft x 10 ft x 2 ft cells (given in metres below: 6.096 x 3.048 x 0.6096).
        self.permx = load_single_keyword('permx.in', 'PERMX')
        self.permy = load_single_keyword('permy.in', 'PERMY')
        self.permz = load_single_keyword('permz.in', 'PERMZ')
        self.poro = load_single_keyword('poro.in', 'PORO')
        self.depth = load_single_keyword('depth.in', 'DEPTH')

        self.reservoir = StructReservoir(self.timer, nx=60, ny=220, nz=1, dx=6.096, dy=3.048, dz=0.6096,
                                         permx=self.permx, permy=self.permy, permz=self.permz,
                                         poro=self.poro, depth=self.depth)

        well_dia = 0.3048  # 1 ft expressed in metres
        well_rad = well_dia / 2

        # (name, i, j, k) for every well. The injector must stay first:
        # set_boundary_conditions() treats wells[0] as the injector.
        # Indices look 1-based (corner producers use 1 and nx=60 / ny=220);
        # NOTE(review): confirm against StructReservoir.add_perforation.
        well_locations = (
            ("I1", 30, 110, 1),
            ("P1", 1, 1, 1),
            ("P2", 60, 1, 1),
            ("P3", 1, 220, 1),
            ("P4", 60, 220, 1),
        )
        for name, i, j, k in well_locations:
            self.reservoir.add_well(name, wellbore_diameter=well_dia)
            self.reservoir.add_perforation(well=self.reservoir.wells[-1], i=i, j=j, k=k,
                                           well_radius=well_rad, multi_segment=False)

        # create physics; NOTE(review): pressure units of min_p/max_p (presumably bar) not shown here
        self.physics = DeadOil(timer=self.timer, physics_filename='physics.in', n_points=n_points, min_p=0, max_p=800,
                               min_z=1e-12)

        # time-stepping: start at first_ts, multiply by mult_ts after each step, cap at max_ts
        self.params.first_ts = 0.01
        self.params.mult_ts = 2
        self.params.max_ts = 5
        self.params.tolerance_newton = 1e-3
        self.params.tolerance_linear = 1e-3
        # Newton updates use local chopping; 0.2 is passed as its chop parameter
        self.params.newton_type = sim_params.newton_local_chop
        self.params.newton_params = value_vector([0.2])

        # total simulation time; presumably in days — confirm with the run script
        self.runtime = 2000

        self.timer.node["initialization"].stop()

    def set_initial_conditions(self):
        """Set uniform initial pressure (250) and composition (0.2357) on the reservoir mesh."""
        self.physics.set_uniform_initial_conditions(self.reservoir.mesh, uniform_pressure=250, uniform_composition=[0.2357])

    def set_op_list(self):
        """Assign operator sets: well blocks (after the reservoir blocks) use operator index 1.

        Index 0 is ``acc_flux_itor`` and index 1 is ``acc_flux_w_itor``. Reservoir
        blocks keep the op_num they already have (presumably 0 by default — confirm).
        """
        # zero-copy numpy view so the assignment below writes into the mesh data
        self.op_num = np.array(self.reservoir.mesh.op_num, copy=False)
        n_res = self.reservoir.mesh.n_res_blocks
        self.op_num[n_res:] = 1
        self.op_list = [self.physics.acc_flux_itor, self.physics.acc_flux_w_itor]

    def set_boundary_conditions(self):
        """Control wells: the first well injects water at rate 5, all others produce at BHP 150."""
        for i, w in enumerate(self.reservoir.wells):
            if i == 0:
                w.control = self.physics.new_rate_water_inj(5)
            else:
                w.control = self.physics.new_bhp_prod(150)

    def export_sat_vtk(self, file_name='data', local_cell_data=None, global_cell_data=None, vars_data_dtype=np.float32,
                       export_grid_data=True):
        """Export the primary variables plus oil and water saturation to VTK at the current engine time.

        :param file_name: base file name forwarded to ``reservoir.export_vtk``
        :param local_cell_data: optional dict of per-cell arrays. It is updated in
            place with one entry per physics variable plus 'So' and 'Sw'. A fresh
            dict is used when omitted.
        :param global_cell_data: optional dict forwarded to ``export_vtk``; a
            fresh dict is used when omitted
        :param vars_data_dtype: dtype every exported array is cast to
        :param export_grid_data: forwarded to ``export_vtk``
        """
        # Create fresh dicts per call. A ``{}`` default would be a single shared object
        # that is mutated here and kept alive (with its arrays) across calls.
        if local_cell_data is None:
            local_cell_data = {}
        if global_cell_data is None:
            global_cell_data = {}

        # get current engine time
        t = self.physics.engine.t
        nb = self.reservoir.mesh.n_res_blocks
        nv = self.physics.n_vars
        X = np.array(self.physics.engine.X, copy=False)

        darts_vec = value_vector([0] * self.reservoir.nb)
        # create numpy wrapper around darts data (zero-copy view of darts_vec)
        np_vec = np.array(darts_vec, copy=False)

        # fill darts_vec (and therefore np_vec) with water saturation per block
        self.physics.do_water_sat_ev.evaluate(self.physics.engine.X, self.reservoir.nb, darts_vec)

        # X stores the nv variables of each block contiguously; a stride-nv slice
        # starting at v over the first nb blocks extracts variable v
        for v in range(nv):
            local_cell_data[self.physics.vars[v]] = X[v:nb * nv:nv].astype(vars_data_dtype)

        # two-phase system: oil saturation is the complement of water saturation
        local_cell_data['So'] = (1 - np_vec).astype(vars_data_dtype)
        local_cell_data['Sw'] = np_vec.astype(vars_data_dtype)

        self.reservoir.export_vtk(file_name, t, local_cell_data, global_cell_data, export_grid_data)
