from darts.models.reservoirs.struct_reservoir import StructReservoir
from darts.models.physics.black_oil import BlackOil
from darts.models.darts_model import DartsModel
from darts.engines import value_vector, sim_params
import numpy as np
from darts.tools.keyword_file_tools import load_single_keyword
import os

class Model(DartsModel):
    """Black-oil DARTS model of a 15-layer UNISIM sector with 14 producers and 11 injectors.

    Grid properties are read from 'reservoir.in' and PVT data from 'pvt.in'.
    The grid is built as a StructReservoir (coord/zcorn are passed as 0).
    """

    # Structured grid dimensions used for the simulation.
    _NX, _NY, _NZ = 81, 58, 15
    # Leading layers of each 'reservoir.in' keyword that are discarded. The deck is
    # described as holding 81*58*20 cells, so 20 - 5 = 15 layers remain (= _NZ).
    _N_SKIPPED_LAYERS = 5

    # Well tables: (name, i, j, layers that are NOT perforated).
    # Every other layer in 1..nz is perforated at column (i, j).
    # Producers must stay first: set_boundary_conditions treats the first
    # len(_PRODUCERS) wells as producers and all following wells as injectors.
    _PRODUCERS = (
        ("NA1A", 38, 36, (9, 15)),
        ("NA2", 21, 36, (9, 15)),
        ("NA3D", 44, 43, (4, 9, 11, 14, 15)),
        ("RJS19", 31, 27, (9, 12, 13, 14, 15)),
        ("PROD005", 33, 18, (5, 9, 15)),
        ("PROD008", 19, 30, (9, 15)),
        ("PROD009", 15, 40, (9, 15)),
        ("PROD010", 36, 42, (4, 9, 13, 14, 15)),
        ("PROD012", 46, 23, (9, 15)),
        ("PROD014", 50, 18, (9, 15)),
        ("PROD021", 27, 41, (9, 14, 15)),
        ("PROD023A", 65, 23, (8, 9, 15)),
        ("PROD024A", 61, 35, (9, 15)),
        ("PROD025A", 57, 23, (9, 15)),
    )
    _INJECTORS = (
        ("INJ003", 49, 23, (9, 15)),
        ("INJ005", 31, 19, (9, 15)),
        ("INJ006", 48, 34, (9, 15)),
        ("INJ007", 59, 17, (4, 9, 15)),
        ("INJ010", 55, 30, (9, 15)),
        ("INJ015", 36, 28, (9, 15)),
        ("INJ017", 33, 39, (9, 14, 15)),
        ("INJ019", 29, 41, (8, 9, 14, 15)),
        ("INJ021", 24, 28, (9, 15)),
        ("INJ022", 48, 11, (9, 15)),
        ("INJ023", 42, 18, (4, 9, 15)),
    )

    def __init__(self, n_points=1000):
        """Build the reservoir, wells, black-oil physics and solver parameters.

        :param n_points: number of points passed to the BlackOil physics
            (presumably its interpolation resolution; confirm in BlackOil).
        """
        # call base class constructor
        super().__init__()
        self.n_points = n_points
        # measure time spent on reading/initialization
        self.timer.node["initialization"].start()

        # Grid properties from the UNISIM deck. Only the layers after the first
        # _N_SKIPPED_LAYERS are kept. This assumes the usual i-fastest, k-slowest
        # keyword ordering.
        self.permx = self._load_kept_layers('PERMX')
        self.permy = self._load_kept_layers('PERMY')
        self.permz = self._load_kept_layers('PERMZ')
        self.poro = self._load_kept_layers('PORO')
        self.dx = self._load_kept_layers('DX')
        self.dy = self._load_kept_layers('DY')
        self.dz = self._load_kept_layers('DZ')
        self.depth = self._load_kept_layers('DEPTH')
        self.actnum = self._load_kept_layers('ACTNUM')

        self.reservoir = StructReservoir(self.timer, nx=self._NX, ny=self._NY, nz=self._NZ,
                                         dx=self.dx, dy=self.dy, dz=self.dz,
                                         permx=self.permx, permy=self.permy, permz=self.permz, poro=self.poro,
                                         depth=self.depth, actnum=self.actnum, coord=0, zcorn=0)
        self._add_wells(well_dia=0.152)

        # physics definition
        self.log_based = False
        self.obl_min_comp = 1e-10
        # self.n_points is intentionally not reassigned here: the constructor
        # argument must reach BlackOil (it used to be overwritten with 1000).
        self.min_p = 20
        self.max_p = 400
        self.min_z = self.obl_min_comp
        self.max_z = 1 - self.obl_min_comp
        self.inj_stream = value_vector([1e-8, 1e-8])  # composition injection stream

        if self.log_based:
            min_z = self.min_z
            self.min_z = np.log(min_z)
            self.max_z = np.log(1 - min_z)
            # NOTE(review): np.log returns an ndarray, not a value_vector. Confirm the
            # physics well-control factories accept it before enabling log_based.
            self.inj_stream = np.log(self.inj_stream)

        self.physics = BlackOil(self.timer, 'pvt.in', self.n_points, self.min_p, self.max_p, self.min_z, self.max_z)

        self.params.first_ts = 1e-2
        self.params.mult_ts = 2
        self.params.max_ts = 15

        # Newton tolerance is relatively high because of L2-norm for residual and well segments
        self.params.tolerance_newton = 1e-3
        self.params.tolerance_linear = 1e-3
        self.params.max_i_newton = 20
        self.params.max_i_linear = 30
        self.params.newton_type = sim_params.newton_local_chop
        # presumably the chop limit used by newton_local_chop (confirm in engine docs)
        self.params.newton_params[0] = 0.2
        self.runtime = 900
        self.timer.node["initialization"].stop()

    def _load_kept_layers(self, keyword):
        """Read `keyword` from 'reservoir.in' and drop its first _N_SKIPPED_LAYERS layers."""
        values = load_single_keyword('reservoir.in', keyword)
        return values[self._NX * self._NY * self._N_SKIPPED_LAYERS:]

    def _add_wells(self, well_dia):
        """Add all producers, then all injectors, from the class well tables.

        Each well is perforated at its (i, j) column in every layer 1..nz that is
        not listed in its excluded-layer tuple.

        :param well_dia: wellbore diameter; half of it is used as the perforation well radius
        """
        well_rad = well_dia / 2
        for name, i, j, closed_layers in self._PRODUCERS + self._INJECTORS:
            self.reservoir.add_well(name, wellbore_diameter=well_dia)
            well = self.reservoir.wells[-1]
            # Layers are numbered 1..nz (the original code passed i+1 for i in range(nz)).
            for k in range(1, self.reservoir.nz + 1):
                if k not in closed_layers:
                    self.reservoir.add_perforation(well, i, j, k, well_radius=well_rad, multi_segment=False)

    def set_op_list(self):
        """Select interpolators per block; blocks past the reservoir blocks get acc_flux_w_itor."""
        # copy=False: op_num must be a view so the assignment below writes into the mesh.
        self.op_num = np.array(self.reservoir.mesh.op_num, copy=False)
        n_res = self.reservoir.mesh.n_res_blocks
        # Presumably op_num indexes op_list. Entries from n_res on (well blocks) -> 1;
        # reservoir blocks keep whatever value mesh.op_num already holds.
        self.op_num[n_res:] = 1
        self.op_list = [self.physics.acc_flux_itor, self.physics.acc_flux_w_itor]

    def set_initial_conditions(self):
        """Set uniform initial pressure 320 and composition [0.07296, 0.4832] (log-transformed if log_based)."""
        if self.log_based:
            comp = np.log([0.07296, 0.4832])
        else:
            comp = np.array([0.07296, 0.4832])

        self.physics.set_uniform_initial_conditions(self.reservoir.mesh, 320, comp)

    def set_boundary_conditions(self):
        """Attach well controls.

        Injectors: water-rate control 1200 with an injection BHP constraint of 343.2.
        Producers: oil-rate control 800 with a production BHP constraint of 35.3.
        """
        # Wells were added producers-first (see _add_wells), so index decides the role.
        n_producers = len(self._PRODUCERS)
        for i, w in enumerate(self.reservoir.wells):
            if i >= n_producers:
                w.control = self.physics.new_rate_water_inj(1200, self.inj_stream)
                w.constraint = self.physics.new_bhp_inj(343.2, self.inj_stream)
            else:
                w.control = self.physics.new_rate_oil_prod(800)
                w.constraint = self.physics.new_bhp_prod(35.3)

    def export_sat_vtk(self, file_name='data', local_cell_data=None, global_cell_data=None, vars_data_dtype=np.float32,
                       export_grid_data=True):
        """Export primary variables plus oil ('So') and water ('Sw') saturation to VTK.

        :param file_name: base name passed to reservoir.export_vtk
        :param local_cell_data: optional dict of per-cell arrays. If given, it is
            updated in place with the primary variables and 'So'/'Sw' before export.
        :param global_cell_data: optional dict passed through to reservoir.export_vtk
        :param vars_data_dtype: dtype the exported arrays are cast to
        :param export_grid_data: passed through to reservoir.export_vtk
        """
        # Fresh dicts per call. Shared `{}` defaults would leak keys between exports.
        if local_cell_data is None:
            local_cell_data = {}
        if global_cell_data is None:
            global_cell_data = {}

        # get current engine time
        t = self.physics.engine.t
        nb = self.reservoir.mesh.n_res_blocks
        nv = self.physics.n_vars
        X = np.array(self.physics.engine.X, copy=False)

        oil_sat = value_vector([0] * self.reservoir.nb)
        water_sat = value_vector([0] * self.reservoir.nb)
        self.physics.bo_oil_sat_ev.evaluate(self.physics.engine.X, self.reservoir.nb, oil_sat)
        self.physics.bo_water_sat_ev.evaluate(self.physics.engine.X, self.reservoir.nb, water_sat)

        # X holds nv unknowns per block, interleaved. Take variable v for the first nb blocks.
        for v in range(nv):
            local_cell_data[self.physics.vars[v]] = X[v:nb * nv:nv].astype(vars_data_dtype)

        # Converted after evaluate(), so the values are correct whether asarray views or copies.
        local_cell_data['So'] = np.asarray(oil_sat).astype(vars_data_dtype)
        local_cell_data['Sw'] = np.asarray(water_sat).astype(vars_data_dtype)

        self.reservoir.export_vtk(file_name, t, local_cell_data, global_cell_data, export_grid_data)

