# from reservoir import Reservoir
from darts.models.reservoirs.struct_reservoir import StructReservoir
from darts.models.physics.black_oil import BlackOil
from darts.models.darts_model import DartsModel
from darts.engines import value_vector, sim_params
import numpy as np
from darts.tools.keyword_file_tools import load_single_keyword
import os

class Model(DartsModel):
    """Black-oil model on a 24 x 25 x 15 structured grid with one injector and 25 producers.

    The well ``INJE1`` is controlled by a water-injection rate; ``PROD2`` ... ``PROD26``
    are controlled by an oil-production rate with a BHP constraint (see
    ``set_boundary_conditions``). Grid properties are read from ``permx.in``,
    ``poro.in``, ``dz.in`` and ``depth.in``; initial conditions from ``zg.in``,
    ``zo.in`` and ``Pressure.in``; the ``BlackOil`` physics is built from ``physics.in``.
    """

    # Wellbore diameter shared by every well; the perforation radius is half of it.
    _WELL_DIAMETER = 0.3048

    # Injector as (name, i, j, perforated k layers). Grid indices are 1-based
    # (i=24, j=25 is the last column of the 24 x 25 grid).
    _INJECTOR = ("INJE1", 24, 25, (11, 12, 13, 14, 15))

    # Every producer is perforated in these k layers.
    _PROD_LAYERS = (2, 3, 4)

    # Producers as (name, i, j), in the order they are added to the reservoir.
    _PRODUCERS = (
        ("PROD2", 5, 1),
        ("PROD3", 8, 2),
        ("PROD4", 11, 3),
        ("PROD5", 10, 4),
        ("PROD6", 12, 5),
        ("PROD7", 4, 6),
        ("PROD8", 8, 7),
        ("PROD9", 14, 8),
        ("PROD10", 11, 9),
        ("PROD11", 12, 10),
        ("PROD12", 10, 11),
        ("PROD13", 5, 12),
        ("PROD14", 8, 13),
        ("PROD15", 11, 14),
        ("PROD16", 13, 15),
        ("PROD17", 15, 16),
        ("PROD18", 11, 17),
        ("PROD19", 12, 18),
        ("PROD20", 5, 19),
        ("PROD21", 8, 20),
        ("PROD22", 11, 21),
        ("PROD23", 15, 22),
        ("PROD24", 12, 23),
        ("PROD25", 10, 24),
        ("PROD26", 17, 25),
    )

    # Initial pressure / gas / oil composition of each well, one entry per well in
    # the order INJE1, PROD2, ..., PROD26. Each value is used for two mesh blocks
    # per well (see set_initial_conditions).
    _WELL_INIT_PRESSURE = (
        289.97, 253.703, 257.974, 262.245, 260.821, 263.669,
        252.28, 257.974, 266.516, 262.245, 263.669, 260.821,
        253.703, 257.974, 262.796, 265.644, 268.491, 262.796,
        264.22, 254.255, 258.526, 262.796, 268.491, 264.22,
        261.373, 271.339,
    )
    _WELL_INIT_ZG = (
        0.00299, 0.2256, 0.2216, 0.2097, 0.2166, 0.2019,
        0.2264, 0.2216, 0.1737, 0.2097, 0.2019, 0.2166,
        0.2256, 0.2216, 0.2068, 0.1784, 0.173, 0.2068,
        0.1959, 0.2253, 0.2207, 0.2068, 0.173, 0.1959,
        0.2142, 0.172,
    )
    _WELL_INIT_ZO = (
        0.007745, 0.5844, 0.5741, 0.543, 0.5611, 0.5228,
        0.5863, 0.5741, 0.4499, 0.543, 0.5228, 0.5611,
        0.5844, 0.5741, 0.5356, 0.462, 0.4482, 0.5356,
        0.5074, 0.5836, 0.5716, 0.5356, 0.4482, 0.5074,
        0.5548, 0.4454,
    )

    # BHP limit used as the constraint on every producer (set_boundary_conditions
    # and change_pro_rate must agree on it).
    _PROD_BHP_LIMIT = 70

    def __init__(self):
        """Read grid data, build the reservoir and wells, and configure physics and solver."""
        # call base class constructor
        super().__init__()

        # measure time spent on reading/initialization
        self.timer.node["initialization"].start()

        # Rock properties: ky is the same array as kx, kz is 1 % of kx.
        self.kx = load_single_keyword('permx.in', 'PERMX')
        self.ky = self.kx
        self.kz = self.kx*0.01
        self.poro = load_single_keyword('poro.in', 'PORO')
        self.dz = load_single_keyword('dz.in', 'DZ')
        self.depth = load_single_keyword('depth.in', 'DEPTH')

        self.reservoir = StructReservoir(self.timer, nx=24, ny=25, nz=15, dx=90, dy=90,
                                         dz=self.dz, permx=self.kx, permy=self.ky, permz=self.kz,
                                         poro=self.poro, depth=self.depth)

        self._add_wells()

        # physics definition
        self.log_based = False
        self.obl_min_comp = 1e-8
        self.n_points = 1000
        self.min_p = 50
        self.max_p = 500
        self.min_z = self.obl_min_comp
        self.max_z = 1 - self.obl_min_comp
        inj_composition = [1e-6, 1e-6]  # composition of the injection stream

        if self.log_based:
            min_z = self.min_z
            self.min_z = np.log(min_z)
            self.max_z = np.log(1 - min_z)
            # Transform the raw values, so inj_stream stays a value_vector (not an
            # ndarray) in both branches.
            inj_composition = np.log(inj_composition).tolist()
        self.inj_stream = value_vector(inj_composition)

        self.physics = BlackOil(self.timer, 'physics.in', self.n_points, self.min_p, self.max_p, self.min_z, self.max_z)
        self.params.first_ts = 1e-4
        self.params.mult_ts = 2
        self.params.max_ts = 10

        # Newton tolerance is relatively high because of L2-norm for residual and well segments
        self.params.tolerance_newton = 1e-2
        self.params.tolerance_linear = 1e-4
        self.params.max_i_newton = 20
        self.params.max_i_linear = 50
        self.params.newton_type = sim_params.newton_global_chop
        self.params.newton_params[0] = 0.15
        self.params.linear_type = sim_params.cpu_gmres_cpr_amg
        self.runtime = 900
        self.params.log_transform = 1 if self.log_based else 0
        self.timer.node["initialization"].stop()

    def _add_well(self, name, i, j, layers):
        """Add a single-segment well perforated at column (i, j) in each k of ``layers``; return it."""
        self.reservoir.add_well(name, wellbore_diameter=self._WELL_DIAMETER)
        well = self.reservoir.wells[-1]
        well_rad = self._WELL_DIAMETER / 2
        for k in layers:
            self.reservoir.add_perforation(well=well, i=i, j=j, k=k, well_radius=well_rad, multi_segment=False)
        return well

    def _add_wells(self):
        """Add the injector first, then all producers, and record the producers in ``prod_wells``."""
        self._add_well(*self._INJECTOR)
        self.reservoir.prod_wells = []
        for name, i, j in self._PRODUCERS:
            self.reservoir.prod_wells.append(self._add_well(name, i, j, self._PROD_LAYERS))

    def set_op_list(self):
        """Assign operator set 0 to reservoir blocks and operator set 1 to well blocks."""
        # View (no copy) on the mesh's op_num, so the assignment below updates the mesh.
        self.op_num = np.array(self.reservoir.mesh.op_num, copy=False)
        n_res = self.reservoir.mesh.n_res_blocks
        # Every block past the reservoir blocks gets index 1 -> acc_flux_w_itor.
        self.op_num[n_res:] = 1
        self.op_list = [self.physics.acc_flux_itor, self.physics.acc_flux_w_itor]

    def set_initial_conditions(self):
        """Set non-uniform initial pressure/composition for reservoir and well blocks.

        Reservoir values come from ``zg.in``, ``zo.in`` and ``Pressure.in``; per-well
        values from the class tables are appended after them. Compositions are
        log-transformed when ``self.log_based`` is set.
        """
        # Each well's value is used for two consecutive mesh blocks (26 wells ->
        # 52 appended values), presumably the well head and the single well
        # segment -- TODO confirm against the DARTS well-block layout.
        blocks_per_well = 2
        p_well = np.repeat(self._WELL_INIT_PRESSURE, blocks_per_well)
        zg_well = np.repeat(self._WELL_INIT_ZG, blocks_per_well)
        zo_well = np.repeat(self._WELL_INIT_ZO, blocks_per_well)

        self.zg = load_single_keyword('zg.in', 'ZG')
        self.zg = np.append(self.zg, zg_well)
        self.zo = load_single_keyword('zo.in', 'ZO')
        self.zo = np.append(self.zo, zo_well)
        if self.log_based:
            self.zg = np.log(self.zg)
            self.zo = np.log(self.zo)
        self.pressure = load_single_keyword('Pressure.in', 'PRES')
        self.pressure = np.append(self.pressure, p_well)

        self.physics.set_nonuniform_initial_conditions(self.reservoir.mesh, self.pressure, self.zg, self.zo)

    def set_boundary_conditions(self):
        """Configure well controls.

        The first well is water-rate-controlled (795) with a BHP constraint of 300.
        Every other well is oil-rate-controlled (200) with a BHP constraint of
        ``_PROD_BHP_LIMIT``.
        """
        for i, w in enumerate(self.reservoir.wells):
            if i == 0:
                w.control = self.physics.new_rate_water_inj(795, self.inj_stream)
                w.constraint = self.physics.new_bhp_inj(300, self.inj_stream)
            else:
                w.control = self.physics.new_rate_oil_prod(200)
                w.constraint = self.physics.new_bhp_prod(self._PROD_BHP_LIMIT)

    def change_pro_rate(self, new_rate):
        """Set the oil-rate control of every producer (all wells except the first) to ``new_rate``.

        The BHP constraint is reset to ``_PROD_BHP_LIMIT``.
        """
        for w in self.reservoir.wells[1:]:
            w.control = self.physics.new_rate_oil_prod(new_rate)
            w.constraint = self.physics.new_bhp_prod(self._PROD_BHP_LIMIT)

    def export_sat_vtk(self, file_name='data', local_cell_data=None, global_cell_data=None, vars_data_dtype=np.float32,
                       export_grid_data=True):
        """Export primary variables plus oil/water saturations of reservoir blocks to VTK.

        :param file_name: output name passed to ``reservoir.export_vtk``
        :param local_cell_data: dict filled in place with per-cell arrays; a new dict
            is used when omitted
        :param global_cell_data: extra global data forwarded to ``export_vtk``; a new
            dict is used when omitted
        :param vars_data_dtype: dtype the exported arrays are cast to
        :param export_grid_data: forwarded to ``export_vtk``
        """
        # Fresh dicts per call: the former ``{}`` defaults were shared between calls,
        # and local_cell_data is mutated below, so keys leaked across exports.
        if local_cell_data is None:
            local_cell_data = {}
        if global_cell_data is None:
            global_cell_data = {}

        # get current engine time
        t = self.physics.engine.t
        nb = self.reservoir.mesh.n_res_blocks
        nv = self.physics.n_vars
        X = np.array(self.physics.engine.X, copy=False)

        # Output buffers for the saturation evaluators. NOTE(review): sized by
        # reservoir.nb, presumably equal to mesh.n_res_blocks -- confirm.
        so_vec = value_vector([0] * self.reservoir.nb)
        sw_vec = value_vector([0] * self.reservoir.nb)

        self.physics.bo_oil_sat_ev.evaluate(self.physics.engine.X, self.reservoir.nb, so_vec)
        self.physics.bo_water_sat_ev.evaluate(self.physics.engine.X, self.reservoir.nb, sw_vec)

        # X is interleaved per block: variable v of the first nb blocks is X[v::nv].
        for v in range(nv):
            local_cell_data[self.physics.vars[v]] = X[v:nb * nv:nv].astype(vars_data_dtype)

        local_cell_data['So'] = np.array(so_vec, copy=False).astype(vars_data_dtype)
        local_cell_data['Sw'] = np.array(sw_vec, copy=False).astype(vars_data_dtype)

        self.reservoir.export_vtk(file_name, t, local_cell_data, global_cell_data, export_grid_data)