from __future__ import print_function, division
from scipy.interpolate import RectBivariateSpline
import numpy as np
from numpy.lib.stride_tricks import as_strided as ast
from matplotlib.mlab import griddata


def grid_interp(x, y, Z, px, py, mode='nearest', **kwargs):
    """
    Sample gridded values at scattered points with the selected method.

    :param x: grid x coordinates
    :param y: grid y coordinates
    :param Z: gridded values, passed through to the selected interpolator
    :param px: x coordinates of the query points
    :param py: y coordinates of the query points
    :param mode: one of ``'nearest'``, ``'linear'`` or ``'spline'``
    :param kwargs: forwarded unchanged to the selected interpolator
    :return: whatever the selected interpolator returns
    :raises ValueError: if ``mode`` is not one of the known names
    """
    interpolators = {
        'nearest': grid_interp_nearest,
        'linear': grid_interp_linear,
        'spline': grid_interp_spline,
    }
    if mode not in interpolators:
        raise ValueError('mode not in {}'.format(tuple(interpolators.keys())))

    interpolate = interpolators[mode]
    return interpolate(x, y, Z, px, py, **kwargs)


def grid2grid(x, y, Z, xi, yi, mask=None, **kwargs):
    """
    Resample gridded values onto the grid spanned by ``xi`` and ``yi``.

    The target nodes are generated with ``np.meshgrid(xi, yi)``, flattened,
    interpolated through ``grid_interp`` and reshaped back to the grid.

    :param x: source grid x coordinates
    :param y: source grid y coordinates
    :param Z: source gridded values
    :param xi: target grid x coordinates
    :param yi: target grid y coordinates
    :param mask: ``None`` returns the plain result; ``True`` masks invalid
        (NaN/inf) entries; any other value is used as the mask of the
        returned ``MaskedArray``
    :param kwargs: forwarded to ``grid_interp`` (e.g. ``mode``)
    :return: array shaped like the target mesh, masked when ``mask`` is given
    """
    Xi, Yi = np.meshgrid(xi, yi)
    flat_values = grid_interp(x, y, Z, Xi.flatten(), Yi.flatten(), **kwargs)
    regridded = flat_values.reshape(Xi.shape)

    if mask is None:
        return regridded
    if mask is True:
        return np.ma.masked_invalid(regridded)
    return np.ma.MaskedArray(regridded, mask)


def grid2grid2(x, y, Z, xi, yi, mask=None, **kwargs):
    """
    Re-bin a grid onto new nodes by nearest-node averaging (``points2grid``).

    Every selected source node is treated as a scattered sample located at
    its ``np.meshgrid(x, y)`` coordinate and handed to ``points2grid``.

    :param x: source grid x coordinates
    :param y: source grid y coordinates
    :param Z: source values, indexable with a boolean array shaped like
        ``np.meshgrid(x, y)``; may be a ``MaskedArray``
    :param xi: target grid x coordinates
    :param yi: target grid y coordinates
    :param mask: optional boolean selector; when given, the samples where it
        is True are the ones passed on (it is used directly as the index).
        When omitted, all valid samples are used: unmasked entries for a
        ``MaskedArray``, non-NaN entries otherwise.
    :param kwargs: accepted but not used
    :return: the masked array returned by ``points2grid``
    """
    X, Y = np.meshgrid(x, y)

    if mask is not None:
        # NOTE(review): an explicit mask keeps its historical meaning
        # (True = sample is used), which is the opposite of the numpy.ma
        # convention -- confirm against callers before changing it.
        return points2grid(X[mask], Y[mask], Z[mask], xi, yi)

    if isinstance(Z, np.ma.MaskedArray):
        # getmaskarray always yields a full boolean array, even when
        # Z.mask is the scalar ``nomask``.
        valid = ~np.ma.getmaskarray(Z)
        values = np.ma.getdata(Z)
    else:
        valid = ~np.isnan(Z)
        values = Z

    # Select the VALID samples; selecting with the invalid-mask itself would
    # feed only the NaN/masked entries into the binning.
    return points2grid(X[valid], Y[valid], values[valid], xi, yi)


def grid2grid_3D(x, y, t, data, xi, yi, mask=None):
    """
    Interpolate a stack of 2-D grids onto a new grid, one slice at a time.

    Each slice ``data[i]`` is fitted with a ``RectBivariateSpline`` using
    SciPy's default degrees (bicubic) and evaluated at every node of the
    target grid.

    :param x: source x coordinates (second axis of each slice)
    :param y: source y coordinates (first axis of each slice)
    :param t: axis of the stack; only ``t.size`` is used
    :param data: array indexed as ``data[i, :, :]`` with one 2-D slice per
        entry of ``t``
    :param xi: target x coordinates
    :param yi: target y coordinates
    :param mask: optional boolean mask of shape ``(yi.size, xi.size)``
        (True = masked out) applied to every slice; nothing is masked when
        omitted
    :return: ``MaskedArray`` of shape ``(t.size, yi.size, xi.size)``.
        Floating-point input keeps its dtype; any other dtype is returned as
        float64 so interpolated values are not truncated.
    """
    shape = (t.size, yi.size, xi.size)

    # The same 2-D mask applies to every slice; broadcast it once.
    full_mask = np.zeros(shape, dtype=bool)
    if mask is not None:
        full_mask[...] = mask

    # The spline evaluates to floats; storing them in an integer buffer
    # would silently truncate the interpolated values.
    if np.issubdtype(data.dtype, np.floating):
        out_dtype = data.dtype
    else:
        out_dtype = np.float64
    values = np.zeros(shape, dtype=out_dtype)

    X, Y = np.meshgrid(xi, yi)
    px = X.reshape(-1)
    py = Y.reshape(-1)

    for i in range(t.size):
        # y is passed first because it indexes the first axis of the slice.
        spline = RectBivariateSpline(y, x, data[i, :, :])
        values[i, :, :] = spline.ev(py, px).reshape(X.shape)

    return np.ma.MaskedArray(values, mask=full_mask)


def grid_interp_spline(x, y, Z, px, py, kx=5, ky=5):
    """
    Evaluate a smooth bivariate spline through gridded values at points.

    :param x: grid x coordinates (second axis of ``Z``)
    :param y: grid y coordinates (first axis of ``Z``)
    :param Z: gridded values, laid out as ``Z[y, x]``
    :param px: x coordinates of the query points
    :param py: y coordinates of the query points
    :param kx: spline degree along x (default 5, the previous fixed value)
    :param ky: spline degree along y (default 5, the previous fixed value)
    :return: interpolated values, one per query point
    """
    # RectBivariateSpline is built with y as its FIRST coordinate, so SciPy's
    # ``kx`` argument refers to our y direction and its ``ky`` to our x.
    spline = RectBivariateSpline(y, x, Z, kx=ky, ky=kx)
    return spline.ev(py, px)


def grid_interp_linear(x, y, Z, px, py):
    """
    Evaluate a degree-1 (bilinear) spline through gridded values at points.

    :param x: grid x coordinates (second axis of ``Z``)
    :param y: grid y coordinates (first axis of ``Z``)
    :param Z: gridded values, laid out as ``Z[y, x]``
    :param px: x coordinates of the query points
    :param py: y coordinates of the query points
    :return: interpolated values, one per query point
    """
    # y goes first because it indexes the first axis of Z.
    bilinear = RectBivariateSpline(y, x, Z, kx=1, ky=1)
    values = bilinear.ev(py, px)
    return values


def grid_interp_nearest(x, y, Z, px, py):
    """
    Look up, for every query point, the value of the nearest grid node.

    The nearest node is found independently along each axis (smallest
    absolute coordinate difference; the first node wins on ties).

    :param x: 1-D grid x coordinates (second axis of ``Z``)
    :param y: 1-D grid y coordinates (first axis of ``Z``)
    :param Z: gridded values, indexed as ``Z[row, column]``
    :param px: x coordinates of the query points
    :param py: y coordinates of the query points, same shape as ``px``
    :return: float array with the shape of ``px``
    """
    x = np.asarray(x)
    y = np.asarray(y)
    px = np.asarray(px)
    py = np.asarray(py)

    flat_px = px.reshape(-1)
    flat_py = py.reshape(-1)
    out = np.zeros(flat_px.shape)

    for k in range(flat_px.size):
        # Scalar indices: assigning a size-1 *array* into a single element
        # is deprecated in NumPy >= 1.25.
        row = np.argmin(np.absolute(y - flat_py[k]))
        col = np.argmin(np.absolute(x - flat_px[k]))
        out[k] = Z[row, col]

    return out.reshape(px.shape)


def points2grid(px, py, pz, x, y):
    """
    Bin scattered samples onto the nearest grid node, averaging per node.

    Each sample is assigned to the node whose x and y coordinates are
    closest to it; a node that receives several samples holds their
    running mean. Nodes that receive no sample stay NaN and are masked.

    :param px: sample x coordinates
    :param py: sample y coordinates
    :param pz: sample values, indexed in the same order as ``px``/``py``
    :param x: grid x coordinates (columns)
    :param y: grid y coordinates (rows)
    :return: ``MaskedArray`` of shape ``(y.size, x.size)`` with invalid
        (NaN) cells masked
    """
    grid_shape = (y.size, x.size)
    cell_mean = np.zeros(grid_shape) + np.nan
    cell_count = np.zeros(grid_shape)

    # Distance of every sample to every node along each axis; the argmin
    # over the node axis is the index of the nearest node.
    x_dist = np.absolute(px.reshape(1, -1) - x.reshape(-1, 1))
    y_dist = np.absolute(py.reshape(1, -1) - y.reshape(-1, 1))
    col_of = np.argmin(x_dist, axis=0)
    row_of = np.argmin(y_dist, axis=0)

    cells = np.array([row_of, col_of]).T
    for k, (row, col) in enumerate(cells):
        seen = cell_count[row, col]

        if seen == 0:
            # first sample in this cell replaces the NaN placeholder
            cell_mean[row, col] = pz[k]
        else:
            # running mean: new = old*n/(n+1) + sample/(n+1)
            cell_mean[row, col] = cell_mean[row, col]*seen/(seen+1) + pz[k]/(seen+1)

        cell_count[row, col] = seen + 1

    return np.ma.masked_invalid(cell_mean)


def _min_positive_spacing(values):
    """Return the smallest strictly positive gap between the sorted values."""
    gaps = np.diff(np.sort(values))
    gaps = gaps[gaps > 0]
    if gaps.size == 0:
        raise ValueError('cannot infer cellsize: fewer than two distinct '
                         'coordinate values')
    return gaps.min()


def points2gridobj(px, py, pz, cellsize=None, xlim=None, ylim=None):
    """
    Bin scattered samples onto a regular grid and wrap it in a GridDataset.

    :param px: sample x coordinates (array)
    :param py: sample y coordinates (array)
    :param pz: sample values
    :param cellsize: grid spacing. ``None`` infers it per axis as the
        smallest positive gap between sorted coordinates; a real number
        (Python or NumPy scalar) is used for both axes; a 2-item tuple or
        list is unpacked as ``(dx, dy)``
    :param xlim: ``(min, max)`` of the grid in x; defaults to the extent
        of ``px``
    :param ylim: ``(min, max)`` of the grid in y; defaults to the extent
        of ``py``
    :return: ``GridDataset(x, y, Z)`` with ``Z`` from ``points2grid``
    :raises TypeError: if ``cellsize`` is of an unsupported type
    :raises ValueError: if the spacing cannot be inferred from the points
    """
    import numbers

    if cellsize is None:
        dx = _min_positive_spacing(px)
        dy = _min_positive_spacing(py)
    elif isinstance(cellsize, numbers.Real):
        # numbers.Real also covers NumPy integer/float scalars, which are
        # not instances of the built-in int/float.
        dx = dy = cellsize
    elif isinstance(cellsize, (tuple, list)):
        dx, dy = cellsize
    else:
        raise TypeError('unknown cellsize type')

    if xlim is None:
        xlim = px.min(), px.max()
    if ylim is None:
        ylim = py.min(), py.max()

    # The stop is pushed one step past the upper limit so that arange,
    # whose stop is exclusive, can reach the limit itself.
    x = np.arange(xlim[0], xlim[1]+dx, dx)
    y = np.arange(ylim[0], ylim[1]+dy, dy)
    Z = points2grid(px, py, pz, x, y)

    # Imported here rather than at module level, as in the original code.
    from ..datasets import GridDataset
    return GridDataset(x, y, Z)


def aggregate(A, blocksize):
    """
    Group a 1-D or 2-D array into non-overlapping blocks.

    The array is cut into blocks of ``blocksize`` elements along every axis
    and each block is flattened (row-major) into the last axis of the
    result. Trailing elements that do not fill a whole block are dropped.

    :param A: 1-D or 2-D array
    :param blocksize: block edge length, in elements
    :return: array of shape ``(n0 // blocksize, blocksize)`` for 1-D input
        or ``(n0 // blocksize, n1 // blocksize, blocksize ** 2)`` for 2-D
        input
    :raises NotImplementedError: if ``A`` is not 1-D or 2-D
    """
    A = np.asarray(A)
    if A.ndim not in (1, 2):
        raise NotImplementedError('only 1D and 2D arrays supported')

    output_shape = [extent // blocksize for extent in A.shape]

    # Keep only whole blocks along each axis.
    whole = tuple(slice(0, count * blocksize) for count in output_shape)
    trimmed = A[whole]

    # Split every axis into (block index, offset inside block). Plain
    # reshape/transpose honours the array's real strides, so transposed or
    # sliced (non-contiguous) inputs are handled correctly.
    split_shape = []
    for count in output_shape:
        split_shape.extend((count, blocksize))
    blocks = trimmed.reshape(split_shape)

    # Move all block-index axes in front of all within-block axes.
    block_axes = list(range(0, 2 * A.ndim, 2))
    inner_axes = list(range(1, 2 * A.ndim, 2))
    blocks = blocks.transpose(block_axes + inner_axes)

    # Explicit last dimension (instead of -1) also works for empty results.
    return blocks.reshape(output_shape + [blocksize ** A.ndim])


# Short aliases.
# grid2pnt: sample a grid at scattered points (same callable as grid_interp).
grid2pnt = grid_interp
# pnt2grid: re-exports the griddata function imported from matplotlib.mlab.
# NOTE(review): matplotlib.mlab.griddata appears to have been removed in
# Matplotlib 3.1 -- confirm the supported Matplotlib version, otherwise the
# import at the top of this module fails.
pnt2grid = griddata