import openpnm as op
import numpy as np
import TiffUtils
import vtk
import os
import glob

import pandas as pd
from xml.etree import ElementTree as ET
from openpnm.io import _parse_filename, project_to_vtk
from openpnm.utils import NestedDict, Workspace

# Skeleton of a VTK XML PolyData document. network_to_vtk parses it with
# ElementTree, sets the NumberOfPoints/NumberOfLines counts on <Piece>, and
# appends DataArray elements under Points, Lines, PointData and CellData.
# .strip() removes the leading newline so the XML declaration is the very
# first thing in the string.
_TEMPLATE = """
<?xml version="1.0" ?>
<VTKFile byte_order="LittleEndian" type="PolyData" version="0.1">
    <PolyData>
        <Piece NumberOfLines="0" NumberOfPoints="0">
            <Points>
            </Points>
            <Lines>
            </Lines>
            <PointData>
            </PointData>
            <CellData>
            </CellData>
        </Piece>
    </PolyData>
</VTKFile>
""".strip()

def _array_to_element(name, array, n=1):
    dtype_map = {
        "int8": "Int8",
        "int16": "Int16",
        "int32": "Int32",
        "int64": "Int64",
        "uint8": "UInt8",
        "uint16": "UInt16",
        "uint32": "UInt32",
        "uint64": "UInt64",
        "float32": "Float32",
        "float64": "Float64",
        "str": "String",
    }
    element = None
    if str(array.dtype) in dtype_map.keys():
        element = ET.Element("DataArray")
        element.set("Name", name)
        element.set("NumberOfComponents", str(n))
        element.set("type", dtype_map[str(array.dtype)])
        element.text = "\t".join(map(str, array.ravel()))
    return element

def network_to_dict(network, categorize_by=('name',), flatten=False, element=None,
                    delim=' | '):
    r"""
    Returns a single dictionary object containing data from the given
    OpenPNM network, with the keys organized differently depending on
    optional arguments.

    Parameters
    ----------
    network : Network
        An OpenPNM network object. It must provide ``props``, ``labels``,
        ``name`` and item access by property name.
    categorize_by : str or list[str]
        Indicates how the dictionaries should be organized. A single string
        is treated as a one-element list and ``None`` as an empty one. The
        list can contain any, all or none of the following strings:

        **'object'** : If specified the dictionary keys will be stored
        under a general level corresponding to their type: 'network' if
        the object has a ``coords`` attribute, otherwise 'phase' (e.g.
        'network | net_01 | pore.all' with the default ``delim``).

        **'name'** : If specified, then the data arrays are additionally
        categorized by their name. This is enabled by default.

        **'data'** : If specified the data arrays are additionally
        categorized by ``labels`` (boolean arrays) and ``properties``
        (everything else).

        **'element'** : If specified the data arrays are additionally
        categorized by ``pore`` and ``throat``. Every '.' in the propname
        is replaced by ``delim``, so the propnames are no longer prefixed
        by 'pore.' or 'throat.'.
    flatten : bool
        If ``True`` a plain ``dict`` keyed by the full ``delim``-joined path
        is returned; otherwise an OpenPNM ``NestedDict`` that uses ``delim``
        as its delimiter.
    element : str, optional
        Forwarded to ``network.props`` and ``network.labels`` to choose which
        arrays are collected. NOTE(review): ``None`` is presumably "pores and
        throats" per OpenPNM; confirm.
    delim : str
        Separator placed between the levels of each key.

    Returns
    -------
    A dictionary with the data stored in a hierarchical data structure, the
    actual format of which depends on the arguments to the function. The
    values are the network's own arrays, not copies.

    """
    # Normalize so that membership tests below are exact-match tests instead
    # of substring tests on a plain string.
    if categorize_by is None:
        categorize_by = ()
    elif isinstance(categorize_by, str):
        categorize_by = [categorize_by]
    categories = set(categorize_by)

    d = {} if flatten else NestedDict(delimiter=delim)

    def build_path(obj, key):
        # Levels are always assembled in the order object / name / data / prop.
        parts = []
        if 'object' in categories:
            parts.append('network' if hasattr(obj, 'coords') else 'phase')
        if 'name' in categories:
            parts.append(obj.name)
        if 'data' in categories:
            parts.append('labels' if obj[key].dtype == bool else 'properties')
        parts.append(key.replace('.', delim) if 'element' in categories else key)
        return delim.join(parts)

    keys = list(network.props(element=element)) + list(network.labels(element=element))
    for key in keys:
        d[build_path(network, key)] = network[key]

    return d


def network_to_vtk(network, filename, fill_nans=None, fill_infs=None):
    r"""
    Writes network to a vtk file, adapted from OpenPNM

    Arrays whose size equals the number of pores are written as point data.
    Arrays whose size equals the number of throats are written as cell data.
    Boolean arrays are stored as integers. Arrays with a dtype that has no
    VTK representation are skipped with a message.

    Arguments
    ---------
    network:
        OpenPNM network (must support ``props``/``labels``, see
        ``network_to_dict``)
    filename:
        output file name, normalized by ``_parse_filename`` with extension
        ``vtp``
    fill_nans:
        value substituted for NaNs; if ``None`` arrays containing NaNs are
        skipped
    fill_infs:
        value substituted for infs; if ``None`` arrays containing infs are
        skipped

    Raises
    ------
    ValueError
        If ``filename`` is an empty string.

    Notes
    -----
    NaN/inf substitution is done on a copy, so the arrays stored in
    ``network`` are never modified.
    """
    if filename == "":
        raise ValueError("no filename provided!")
    filename = _parse_filename(filename=filename, ext="vtp")

    am = network_to_dict(network=network,
                         categorize_by=["object", "data"])
    # Flatten the nested dictionary to dot-separated keys, then use ' | ' as
    # the separator in the array names written to the file.
    am = pd.json_normalize(am, sep='.').to_dict(orient='records')[0]
    am = {key.replace('.', ' | '): value for key, value in am.items()}

    points = network["pore.coords"]
    pairs = network["throat.conns"]
    num_points = np.shape(points)[0]
    num_throats = np.shape(pairs)[0]

    root = ET.fromstring(_TEMPLATE)
    piece_node = root.find("PolyData").find("Piece")
    piece_node.set("NumberOfPoints", str(num_points))
    piece_node.set("NumberOfLines", str(num_throats))
    points_node = piece_node.find("Points")
    points_node.append(_array_to_element("coords", points.T.ravel("F"), n=3))
    lines_node = piece_node.find("Lines")
    lines_node.append(_array_to_element("connectivity", pairs))
    # Every line has exactly two points, so the offsets are 2, 4, 6, ...
    lines_node.append(_array_to_element("offsets", 2 * np.arange(len(pairs)) + 2))

    point_data_node = piece_node.find("PointData")
    cell_data_node = piece_node.find("CellData")
    for key in sorted(am):
        array = np.asarray(am[key])
        if array.dtype == bool:
            array = array.astype(int)
        if array.dtype.kind not in "iuf":
            # np.isnan cannot handle strings/objects, and there is no VTK
            # DataArray type for them in _array_to_element either.
            print(f"{key} has unsupported dtype {array.dtype}, will not write to file")
            continue
        nan_mask = np.isnan(array)
        if np.any(nan_mask):
            if fill_nans is None:
                print(key + " has nans," + " will not write to file")
                continue
            # Copy first: the array may be the network's own data.
            array = array.copy()
            array[nan_mask] = fill_nans
        inf_mask = np.isinf(array)
        if np.any(inf_mask):
            if fill_infs is None:
                print(key + " has infs," + " will not write to file")
                continue
            array = array.copy()
            array[inf_mask] = fill_infs
        element = _array_to_element(key, array)
        if element is None:
            # e.g. float16: numeric, but not in the VTK type table.
            print(f"{key} has unsupported dtype {array.dtype}, will not write to file")
            continue
        if array.size == num_points:
            point_data_node.append(element)
        elif array.size == num_throats:
            cell_data_node.append(element)

    ET.ElementTree(root).write(filename)

    # Put every DataArray on its own line to keep the file human readable.
    with open(filename, "r+") as f:
        string = f.read()
        string = string.replace("</DataArray>", "</DataArray>\n\t\t\t")
        f.seek(0)
        # consider adding header: '<?xml version="1.0"?>\n'+
        f.write(string)
        f.truncate()


def WriteToVTK(particles: vtk.vtkAppendPolyData, out_file: str):
    r"""
    Write the merged output of a ``vtkAppendPolyData`` filter to disk
    using ``vtkPolyDataWriter``.

    particles:
        vtkAppendPolyData instance; it is updated before its output is read
    out_file:
        path of the file to write
    """
    particles.Update()
    merged = particles.GetOutput()

    poly_writer = vtk.vtkPolyDataWriter()
    poly_writer.SetFileName(out_file)
    poly_writer.SetInputData(merged)
    # Updating the writer runs its pipeline, which writes the file.
    poly_writer.Update()


def WritePoresToVTK(coords, dpore, filename: str, quality: int, filter=None):
    r"""
    Writes pores as spheres in VTK format

    Arguments
    ---------
    coords:
        [Np, 3] array of center coordinates of the pores
    dpore:
        [Np] array with pore diameters
    filename:
        Name of the output file
    quality:
        theta and phi resolution of every ``vtkSphereSource``
    filter:
        Optional callable with signature (coord: ndarray) -> bool. It is
        called with the row of ``coords`` of each pore; pores for which it
        returns a falsy value are not added to the output.

    Raises
    ------
    ValueError
        If ``coords`` and ``dpore`` have a different number of rows.
    """
    coords = np.asarray(coords)
    dpore = np.asarray(dpore)
    if coords.shape[0] != dpore.shape[0]:
        raise ValueError('coordinates and pore incompatible')
    if len(coords) == 0:
        print('No pores provided for writing, skipping writing of ' + filename)
        return

    all_spheres = vtk.vtkAppendPolyData()
    for center, diameter in zip(coords, dpore):
        if filter is not None and not filter(center):
            continue
        sphere = vtk.vtkSphereSource()
        sphere.SetThetaResolution(quality)
        sphere.SetPhiResolution(quality)
        sphere.SetRadius(diameter * 0.5)
        sphere.SetCenter(center[0], center[1], center[2])
        sphere.Update()
        all_spheres.AddInputData(sphere.GetOutput())

    WriteToVTK(all_spheres, filename)


def WriteThroatsToVTK(coords, conns, radii, filename: str, quality: int, filter=None):
    r"""
    Writes throats as cylinders in VTK format

    Each throat is a straight line between the centers of its two pores,
    turned into a tube with ``vtkTubeFilter``.

    Arguments
    ---------
    coords:
        [Np, 3] array of center coordinates of the pores
    conns:
        [Nt, 2] array with the indices of the two pores of each throat
    radii:
        [Nt] array with throat radii
    filename:
        Name of the output file
    quality:
        number of sides of every tube
    filter:
        Optional callable with signature (coord1: ndarray, coord2: ndarray)
        -> bool. It is called with the coordinates of both end pores; throats
        for which it returns a falsy value are not added to the output.

    Raises
    ------
    ValueError
        If ``conns`` and ``radii`` have a different number of rows.
    """
    coords = np.asarray(coords)
    conns = np.asarray(conns)
    radii = np.asarray(radii)
    if conns.shape[0] != radii.shape[0]:
        raise ValueError('radii and throats are incompatible')
    if len(coords) == 0 or len(radii) == 0:
        print('No coordinates provided for writing, skipping writing of ' + filename)
        return

    all_throats = vtk.vtkAppendPolyData()
    for p1, p2, radius in zip(conns[:, 0], conns[:, 1], radii):
        start = coords[p1, :]
        end = coords[p2, :]
        if filter is not None and not filter(start, end):
            continue

        # VTK objects are only created for throats that pass the filter.
        line = vtk.vtkLineSource()
        line.SetPoint1(start[0], start[1], start[2])
        line.SetPoint2(end[0], end[1], end[2])
        line.SetResolution(1)
        line.Update()

        tubefilter = vtk.vtkTubeFilter()
        tubefilter.SetInputData(line.GetOutput())
        tubefilter.SetRadius(radius)
        tubefilter.SetNumberOfSides(quality)
        # Capping is not enabled, so the tube ends are left open.
        tubefilter.Update()
        all_throats.AddInputData(tubefilter.GetOutput())

    WriteToVTK(all_throats, filename)


def WriteNetworkToVTK(network, filename: str, quality: int):
    r"""
    Writes the whole network to VTK files, where the pores are given as
    spheres and the throats as cylinders

    The extension of ``filename`` is dropped and two files are written:
    ``<base>_pores.vtk`` and ``<base>_throats.vtk``.

    Arguments
    ---------
    network:
        OpenPNM network (or dict) providing 'pore.coords', 'throat.conns',
        'pore.diameter' or 'pore.radius', and 'throat.radius' or
        'throat.diameter'
    filename:
        base name of the output file
    quality:
        quality indicator passed to both the pore and the throat writer

    Raises
    ------
    KeyError
        If the pore or throat size data is missing. This is checked before
        any file is written.
    """
    coords = network['pore.coords']
    conns = network['throat.conns']

    if 'pore.diameter' in network:
        dpore = network['pore.diameter']
    elif 'pore.radius' in network:
        dpore = network['pore.radius'] * 2
    else:
        raise KeyError('Neither diameter nor radius are provided for the pore')

    # Resolve throat sizes up front so a missing array does not leave a
    # pores file behind without its matching throats file.
    if 'throat.radius' in network:
        rthroat = network['throat.radius']
    elif 'throat.diameter' in network:
        rthroat = network['throat.diameter'] * 0.5
    else:
        raise KeyError('Neither radius nor diameter are provided for the throat')

    base_path, _ = os.path.splitext(filename)

    WritePoresToVTK(filename=base_path + '_pores.vtk', coords=coords, dpore=dpore, quality=quality)
    WriteThroatsToVTK(filename=base_path + '_throats.vtk', coords=coords, conns=conns,
                      radii=rthroat, quality=quality)


# path to files and base name
network_base_path = '<path_to_folder_with_extracted_network>'
prefix = '089_pore_space'
image_base_path = '<path_to_segmented_images>'
output_path = '<path_to_output>'

# Sorted so the time steps are processed (and reported) in a stable order.
list_images = sorted(glob.glob(image_base_path + '*tiff'))

num_images = len(list_images)
for count, image_path in enumerate(list_images, start=1):
    print(f'{count}/{num_images} - {image_path}')
    # determine image properties and prepare filenames
    base_image_path, _ = os.path.splitext(image_path)
    tstep = base_image_path.split('_')[-1]

    # create file names
    fname_network = output_path + 'network_' + tstep
    fname_throats_p1 = output_path + 'throats_p1_' + tstep + '.vtk'
    fname_throats_p2 = output_path + 'throats_p2_' + tstep + '.vtk'
    fname_pores_p1 = output_path + 'pores_p1_' + tstep + '.vtk'
    fname_pores_p2 = output_path + 'pores_p2_' + tstep + '.vtk'

    # read in network; re-read every iteration because the columns of
    # network['pore.coords'] are swapped in place further below
    network = op.io.network_from_statoil(path=network_base_path, prefix=prefix)

    # read in image
    image = TiffUtils.ReadTiffStack(image=image_path)
    image = np.swapaxes(image, 0, 1)

    # label pores by phase
    # NOTE(review): truncating the coordinates to int assumes they are voxel
    # indices of the image (not physical units) - confirm.
    coords = network['pore.coords'].astype(int)
    coords[:, 0], coords[:, 2] = coords[:, 2], coords[:, 0].copy()

    # Vectorized lookup of the image value at every pore center.
    pore_phase = image[coords[:, 0], coords[:, 1], coords[:, 2]]
    conns = network['throat.conns']
    throat_start = coords[conns[:, 0], :]
    throat_end = coords[conns[:, 1], :]
    throat_phase = (pore_phase[conns[:, 0]] + pore_phase[conns[:, 1]]) * 0.5

    network['pore.phase'] = pore_phase
    network['throat.start'] = throat_start
    network['throat.end'] = throat_end
    network['throat.phase'] = throat_phase

    # map also for output (swaps the columns of network['pore.coords'] in place)
    coords = network['pore.coords']
    coords[:, 0], coords[:, 2] = coords[:, 2], coords[:, 0].copy()

    def FilterPore(p, phase: int):
        point = p.astype(int)
        return image[point[0], point[1], point[2]] == phase

    def FilterPore1(p):
        return FilterPore(p, 1)

    def FilterPore2(p):
        return FilterPore(p, 2)

    def FilterThroat(p1, p2, phase: int):
        # A throat belongs to a phase only if both end pores do.
        point1 = p1.astype(int)
        point2 = p2.astype(int)
        return (image[point1[0], point1[1], point1[2]] == phase) and (image[point2[0], point2[1], point2[2]] == phase)

    def FilterThroat1(p1, p2):
        return FilterThroat(p1, p2, 1)

    def FilterThroat2(p1, p2):
        return FilterThroat(p1, p2, 2)

    WriteThroatsToVTK(filename=fname_throats_p1,
                      coords=network["pore.coords"],
                      conns=network["throat.conns"],
                      radii=network['throat.radius'],
                      quality=20,
                      filter=FilterThroat1)
    WriteThroatsToVTK(filename=fname_throats_p2,
                      coords=network["pore.coords"],
                      conns=network["throat.conns"],
                      radii=network['throat.radius'],
                      quality=20,
                      filter=FilterThroat2)
    WritePoresToVTK(filename=fname_pores_p1,
                    coords=network['pore.coords'],
                    dpore=network['pore.radius'] * 2,
                    quality=40,
                    filter=FilterPore1)
    WritePoresToVTK(filename=fname_pores_p2,
                    coords=network['pore.coords'],
                    dpore=network['pore.radius'] * 2,
                    quality=40,
                    filter=FilterPore2)

    # The full geometry only depends on the (identically re-read) network,
    # not on the image, so generate it once and store it with the outputs.
    if count == 1:
        WriteNetworkToVTK(network=network, filename=output_path + 'vtknetwork.vtk', quality=20)

    network_to_vtk(network=network, filename=fname_network)
    op.io.network_to_csv(network=network, filename=fname_network)