from __future__ import division, print_function
import numpy as np
from scipy import ndimage
import time
import logging


class AnalysisError(Exception):
    """Module-specific exception type.

    Nothing in this section of the module raises it; presumably callers use it
    to signal analysis failures (NOTE(review): confirm against its users).
    """


def _nextpow2(n):
    m_f = np.log2(n)
    m_i = int(np.ceil(m_f))
    return 2**m_i


def _gauss(n, m, steepness=5):
    M = np.array([np.arange(0, m)]*n)
    N = np.array([np.arange(0, n)]*m).T

    m0 = m/2
    n0 = n/2

    sm = m/steepness
    sn = n/steepness

    return np.exp(-((M-m0)**2/(2*sm**2)+(N-n0)**2/(2*sn**2)))


def prepare(data, unmask=False, taper=False, avg=False):
    """Weight (and optionally de-mean) an array.

    :param data: array, optionally a ``numpy.ma.MaskedArray``; may contain NaN.
        Must be 2D when ``taper`` is used.
    :param unmask: if True and the weighted result is a masked array, return
        a plain array with masked entries filled (0 when ``avg``, otherwise the
        mean of the weighted data)
    :param taper: weight with the Gaussian window from ``_gauss`` instead of
        uniform weights
    :param avg: subtract the weighted mean of the valid (unmasked, non-NaN)
        entries before applying the weights
    :return: weighted array; masked if ``data`` was masked and ``unmask`` is
        False
    """
    # Entries without data: masked entries and NaNs. getmaskarray also copes
    # with plain ndarrays and with masked arrays whose mask is ``nomask``.
    invalid = np.ma.getmaskarray(data) | np.isnan(np.ma.getdata(data))

    if taper:
        weights = _gauss(*data.shape)
    else:
        weights = np.ones(data.shape)

    if avg:
        # Weighted mean over valid entries only. nansum skips NaN (and masked)
        # entries in the numerator, so their weights must also be excluded
        # from the denominator.
        wm = np.nansum(data*weights)/np.sum(weights[~invalid])
        data = (data-wm)*weights
    else:
        data = data*weights

    if isinstance(data, np.ma.MaskedArray) and unmask:
        if avg:
            fill_value = 0
        else:
            fill_value = np.mean(data)
        data = data.filled(fill_value)

    return data


def get_limits(v, rf=None):
    """Round the range of ``v`` outward to 'nice' limits.

    Values are scaled by the power of ten of the larger magnitude. In that
    scaled space the minimum is rounded down and the maximum rounded up to a
    multiple of ``rf``, and the result is scaled back.

    :param v: array-like of values (any shape); only its min and max are used
    :param rf: rounding step in scaled units (default 0.2)
    :return: tuple ``(lower, upper)``
    """
    if not isinstance(v, np.ndarray):
        v = np.array(v)
    # always use (min, max): a 2-element input is not assumed to be sorted
    lo, hi = np.amin(v), np.amax(v)

    if rf is None:
        rf = .2

    # zeros have no order of magnitude (log10(0) = -inf); all-zero -> order 0
    magnitudes = np.absolute([lo, hi])
    nonzero = magnitudes[magnitudes > 0]
    order = np.floor(np.log10(nonzero)).max() if nonzero.size else 0.0
    base = 10**order

    # Round the quotients before floor/ceil so values lying exactly on a step
    # are not pushed one step outward by float noise
    # (e.g. 0.6 / 0.2 == 2.9999999999999996).
    lower = base * np.floor(np.round(lo/base/rf, 9))*rf
    upper = base * np.ceil(np.round(hi/base/rf, 9))*rf
    return lower, upper


def crop(data, mask=None):
    """Trim a 2D array to the bounding box of its valid entries.

    :param data: 2D array (plain or masked)
    :param mask: optional boolean array of the same shape; True marks entries
        to keep. Defaults to the unmasked entries of a masked array, or the
        non-NaN entries of a plain array.
    :return: ``data`` sliced to the smallest rectangle containing every True
        entry of ``mask``
    :raises ValueError: if ``mask`` contains no True entries
    """
    if mask is None:
        if isinstance(data, np.ma.MaskedArray):
            # getmaskarray also handles masked arrays whose mask is ``nomask``
            # (a scalar), which previously broke the 2D shape unpacking
            mask = ~np.ma.getmaskarray(data)
        else:
            mask = ~np.isnan(data)
    mask = np.asarray(mask, dtype=bool)

    if not mask.any():
        raise ValueError('no true values in mask')

    # first/last row and column that contain at least one True entry
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    return data[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]


def rotate(data, degrees, nanmask=None, cropped=True):
    """
    rotate 2D array

    Missing values (NaN, or masked entries of a masked array) are set to the
    data mean during interpolation. They reappear as missing values in the
    output, together with the area outside the rotated input frame.

    :param data: 2D data array (plain or masked); non-float input is converted to float
    :param degrees: number of degrees (passed to ``scipy.ndimage.rotate``)
    :param nanmask: boolean mask of missing values; defaults to the NaN (or masked) entries
    :param cropped: boolean to define if output should be cropped to the missing value mask
    :return: rotated 2D array; masked if ``data`` was masked, otherwise a float
        array with NaN marking missing values
    """
    data = np.asanyarray(data)
    if not np.issubdtype(data.dtype, np.floating):
        # NaN is used below to mark missing values, which needs a float dtype
        # (filled(np.nan) on an integer masked array would raise)
        data = data.astype(float)

    # fill masked array if required
    masked = isinstance(data, np.ma.MaskedArray)
    if masked:
        data = data.filled(np.nan)

    # calculate mean
    avg = np.mean(data[~np.isnan(data)])

    # remove mean so the zero used for missing values and for cval below
    # corresponds to the data mean
    data = data - avg

    # calculate min and max in dataset
    minimum = data[~np.isnan(data)].min()
    maximum = data[~np.isnan(data)].max()

    # get mask of nan locations
    if nanmask is None:
        nanmask = np.isnan(data)
    else:
        # force boolean so an integer 0/1 mask is not treated as fancy indices
        nanmask = np.asarray(nanmask, dtype=bool)
    data[nanmask] = 0

    # rotate data and mask; area outside the input frame becomes missing (cval=1)
    rotated = ndimage.rotate(data, degrees, cval=0, order=5)
    rotated_mask = ndimage.rotate(nanmask.astype(int), degrees, cval=1)
    rotated_mask = np.around(rotated_mask, decimals=0).astype(bool)
    rotated[rotated_mask] = np.nan

    rotated = np.ma.masked_invalid(rotated)

    # ensure min and max have not changed: the order-5 spline can overshoot.
    # Clip the underlying data in place; NaN at missing positions stays NaN.
    np.clip(rotated.data, minimum, maximum, out=rotated.data)

    # restore mean
    rotated = rotated + avg

    if not masked:
        rotated = rotated.filled(np.nan)

    if cropped:
        # crop dataset by missing values
        return crop(rotated)
    return rotated


def moving_avg(data, window=9, same_size=False):
    """Centred moving average of a 1D array.

    :param data: 1D array of integer or floating dtype
    :param window: odd, positive window width
    :param same_size: if True, return a masked array with the same length as
        ``data``. The ``(window-1)//2`` samples at each end, where the window
        does not fit, are masked. Otherwise only the ``len(data)-window+1``
        fully covered positions are returned.
    :return: float array of window means
    """
    assert isinstance(window, (int, np.integer)), 'window must be integer'
    assert window % 2, 'window must be uneven'
    assert window > 0, 'window must be >= 1'

    assert len(data.shape) == 1, 'data must by 1D array'
    # ``data.dtype == np.float`` stopped working when NumPy 1.24 removed the
    # ``np.float`` alias. Check the dtype kind instead; integer data is fine
    # because it is accumulated into a float buffer below.
    assert (np.issubdtype(data.dtype, np.floating)
            or np.issubdtype(data.dtype, np.integer)), 'data must be numeric'

    m, = data.shape
    m_out = m-window+1
    summed_data = np.zeros(m_out)

    for i in range(window):
        summed_data = summed_data + data[i:i+m_out]

    # TODO: NaN values -- a NaN currently propagates into every window containing it

    if same_size:
        # start fully masked; assigning the centre values unmasks them
        same_size_data = np.ma.masked_equal(np.zeros(data.shape), 0)
        offset = (window-1)//2
        same_size_data[offset:offset+m_out] = summed_data
        summed_data = same_size_data

        assert data.shape == summed_data.shape

    return summed_data / window


def moving_avg2(data, window=(9,9), same_size=False):
    """Moving average of a 2D array over a rectangular window.

    :param data: 2D array
    :param window: int (square window) or pair ``(rows, cols)`` of window sizes
    :param same_size: if True, place the result in a zero-filled array of the
        input's shape, offset by ``(size-1)//2`` along each axis; otherwise only
        the fully covered positions are returned
    :return: float array of window means
    """
    win = (window, window) if isinstance(window, int) else window

    assert len(win) == 2, 'window must be integer or iterable of length 2'
    assert win[0] > 0 and win[1] > 0, 'window must be >= 1'

    assert len(data.shape) == 2, 'data must by 2D array'

    n_rows, n_cols = data.shape
    win_r, win_c = win[0], win[1]
    out_r = n_rows - win_r + 1
    out_c = n_cols - win_c + 1

    # accumulate every shifted sub-array (row shift outer, column shift inner)
    total = np.zeros((out_r, out_c))
    for di, dj in np.ndindex(win_r, win_c):
        total = total + data[di:di + out_r, dj:dj + out_c]

    if same_size:
        padded = np.zeros(data.shape)
        r0 = (win_r - 1)//2
        c0 = (win_c - 1)//2
        padded[r0:r0 + out_r, c0:c0 + out_c] = total
        total = padded

    return total / (win_r*win_c)


def extrema(data, window=3):
    """Locate local maxima (peaks) and minima (troughs) of a 1D signal.

    The signal is smoothed with ``moving_avg`` to suppress small peaks.
    Candidates are the samples where the sign of the smoothed slope changes.
    Each candidate is then moved to the largest (peaks) or smallest (troughs)
    raw ``data`` value within a ``window``-wide neighbourhood centred on it.

    :param data: 1D array
    :param window: odd window width used for smoothing and for the search
    :return: tuple ``(peaks, troughs)`` of integer index arrays into ``data``
    """
    n = data.shape[0]

    # moving average to remove small peaks (tweak based on cellsize)
    data_avg = moving_avg(data, window=window, same_size=True)

    # Second difference of the slope sign: -1 where the smoothed slope turns
    # from rising to falling (peak), +1 where it turns from falling to rising
    # (trough). Entry k describes sample k + 1, so the candidate indices are
    # 1 .. n-2 (same length as the double diff). Entries touching the masked
    # ends of data_avg are filled with 0 so they are never reported.
    curvature = np.ma.filled(np.sign(np.diff(np.sign(np.diff(data_avg)))), 0)
    indices = np.arange(1, n - 1)
    peaks = indices[curvature == -1]
    troughs = indices[curvature == 1]

    # Search window centred on each candidate (offsets -(w-1)//2 .. +(w-1)//2),
    # clipped so it never wraps around or runs past the end of the data.
    offsets = np.arange(window) - (window - 1)//2
    peakrange = np.clip(peaks[:, np.newaxis] + offsets, 0, n - 1)
    troughrange = np.clip(troughs[:, np.newaxis] + offsets, 0, n - 1)

    # index of the max/min raw value within each candidate's search window
    peaks = peakrange[np.arange(len(peakrange)), np.argmax(data[peakrange], axis=1)]
    troughs = troughrange[np.arange(len(troughrange)), np.argmin(data[troughrange], axis=1)]

    # return peak and trough positions
    return peaks.astype(int), troughs.astype(int)


class Timer(object):
    """Context manager that measures the wall-clock duration of a ``with`` block.

    After the block exits, ``elapsed`` holds the duration in seconds. ``str()``
    renders it as e.g. ``'1.23s'``, or falls back to ``repr()`` while no
    measurement is available.
    """

    def __init__(self):
        self.time = None     # unused; kept for backward compatibility
        self.start = None    # time.time() timestamp taken on __enter__
        self.elapsed = None  # seconds spent in the with-block, set on __exit__

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.elapsed = time.time() - self.start
        # returns None, so exceptions raised inside the block propagate

    def __str__(self):
        if self.elapsed is None:
            # not measured yet: format(None, '.2f') would raise TypeError
            return repr(self)
        return '{:.2f}s'.format(self.elapsed)


class Chainmap(dict):
    """Dict-like read view that looks keys up in several mappings in turn.

    The first mapping (in constructor order) that contains a key wins. Lookups
    are live, so later changes to the underlying mappings are visible. The
    instance's own ``dict`` storage is not consulted by any read method here.
    Writing to the Chainmap (``cm[k] = v``) therefore does not affect lookups;
    treat it as read-only.
    """

    def __init__(self, *maps):
        super(Chainmap, self).__init__()
        self._maps = maps

    def __getitem__(self, key):
        for mapping in self._maps:
            try:
                return mapping[key]
            except KeyError:
                pass
        raise KeyError(key)

    def __contains__(self, key):
        # dict.__contains__ would only look at the (empty) own storage
        return any(key in mapping for mapping in self._maps)

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def _keys(self):
        """Return the union of all mappings' keys in first-seen order."""
        seen = set()
        ordered = []
        for mapping in self._maps:
            for key in mapping.keys():
                if key not in seen:
                    seen.add(key)
                    ordered.append(key)
        return ordered

    def __iter__(self):
        return iter(self._keys())

    def __len__(self):
        return len(self._keys())

    def keys(self):
        return self._keys()

    def values(self):
        return [self[k] for k in self._keys()]

    def items(self):
        return [(k, self[k]) for k in self._keys()]

    def copy(self):
        return self.todict()

    def todict(self):
        return {k: self[k] for k in self._keys()}

    def __eq__(self, other):
        if isinstance(other, Chainmap):
            other = other.todict()
        return self.todict() == other

    def __ne__(self, other):
        return not self == other

    def __repr__(self):
        return repr(self.todict())


if __name__ == '__main__':
    # Self-checks for the moving-average helpers. The 2D cases (including the
    # even window of 2) need moving_avg2; moving_avg accepts only 1D arrays
    # with odd windows.
    data = np.array([[0, 2, 0, 2, 0]]*5, dtype=float)
    assert (moving_avg2(data, 2) == np.ones((4, 4))).all()
    assert (moving_avg2(data, 1) == data).all()
    assert (moving_avg2(data.T, 1) == data.T).all()

    data = np.array([[1, 3, 5, 7, 9, 7, 5, 3, 1]]*5, dtype=float)
    expected = np.array([[3, 5, 7, (7+9+7)/3, 7, 5, 3]]*3)
    assert np.allclose(moving_avg2(data, 3), expected)
    assert np.allclose(moving_avg2(data.T, 3), expected.T)

    # 1D moving average of a single row
    assert np.allclose(moving_avg(data[0], 3), expected[0])


