import numpy as np
from . import interp


def extend_docstring(fn):
    """Build a decorator that appends ``fn.__doc__`` to the decorated object's docstring.

    The decorated object itself is returned (it is not wrapped); only its
    ``__doc__`` may change:

    * if *fn* has no docstring, nothing happens;
    * if the decorated object has no docstring, it receives *fn*'s verbatim;
    * otherwise *fn*'s docstring is appended after a single newline.
    """
    def decorator(fn2):
        extra = fn.__doc__
        if extra is not None:
            current = fn2.__doc__
            fn2.__doc__ = extra if current is None else current + '\n' + extra
        return fn2
    return decorator


def rolling_window_lastaxis(a, window):
    """Return a strided view of sliding windows along the last axis of *a*.

    Directly taken from Erik Rigtorp's post to numpy-discussion.
    <http://www.mail-archive.com/numpy-discussion@scipy.org/msg29450.html>

    Parameters
    ----------
    a : array_like
        Input data; converted with ``np.asarray`` (no copy for ndarrays).
    window : int
        Window length, ``1 <= window <= a.shape[-1]``.

    Returns
    -------
    ndarray
        View of shape ``a.shape[:-1] + (a.shape[-1] - window + 1, window)``
        with ``out[..., i, j] == a[..., i + j]``. Consecutive windows overlap
        in memory, so the view should be treated as read-only.

    Raises
    ------
    ValueError
        If *window* is smaller than 1 or longer than the last axis of *a*.
    """
    a = np.asarray(a)
    if window < 1:
        raise ValueError("window must be at least 1.")
    if window > a.shape[-1]:
        raise ValueError("window is too long.")
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    # The new window axis reuses the last axis' stride, so stepping along either
    # of the two trailing axes advances one element of the original last axis.
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)


def rolling_window(a, window):
    """Return sliding-window views of *a* over one or several axes.

    If *window* is not iterable, windows are taken along the last axis only
    (delegated to ``rolling_window_lastaxis``).

    If *window* is iterable, its i-th entry is the window length along axis i.
    Each entry greater than 1 shortens axis i to ``len - win + 1`` and appends
    one window axis at the end, in the same order as the axes. Entries <= 1
    leave their axis untouched and add no window axis. For example, a 2-D array
    of shape (H, W) with window (3, 3) gives shape (H-2, W-2, 3, 3) and
    ``out[i, j, k, l] == a[i + k, j + l]``.

    The result is a view built from ``swapaxes`` and strided views, not a copy.
    """
    if hasattr(window, '__iter__'):
        rolled = a
        for axis, length in enumerate(window):
            if not length > 1:
                continue  # no window along this axis, and no extra window axis
            # Move the target axis to the end, window it, then swap the shortened
            # axis back into place; the new window axis stays last.
            rolled = rolling_window_lastaxis(rolled.swapaxes(axis, -1), length)
            rolled = rolled.swapaxes(-2, axis)
        return rolled
    return rolling_window_lastaxis(a, window)


def buffer_mask(m, invert=False):
    """Dilate a mask by one cell using a 3x3 neighbourhood maximum.

    Every interior cell (excluding the first/last index along the first two
    axes) is replaced by the maximum of its 3x3 neighbourhood. For a boolean
    mask this means: True if any cell in the neighbourhood is True. When
    *invert* is true the dilated interior is negated with ``~`` before being
    stored. The one-cell border is not recomputed; it is copied from *m*.

    Returns a new array; *m* itself is left unmodified.
    """
    neighbourhoods = rolling_window(m, (3, 3))
    grown = neighbourhoods.max(axis=(-2, -1))
    result = m.copy()
    result[1:-1, 1:-1] = ~grown if invert else grown
    return result


def smooth_mask(m, invert=False):
    """Apply a 3x3 majority test to a mask without ever clearing set cells.

    For every interior cell (excluding the first/last index along the first two
    axes) the values in its 3x3 neighbourhood are summed. The cell passes when
    that sum exceeds 4; for a boolean mask this means at least 5 of the 9 cells
    are set. When *invert* is true the pass/fail result is negated. The result
    is then OR-ed with the original interior, so cells already set in *m* stay
    set. The one-cell border is copied unchanged from *m*.

    Returns a new array; *m* itself is left unmodified.
    """
    # Sum directly over the two window axes of the strided view. Reshaping that
    # overlapping view would force a full 9x copy. Summing over (-2, -1) also
    # keeps any trailing axes of m aligned with m[1:-1, 1:-1], just as
    # buffer_mask does.
    counts = rolling_window(m, (3, 3)).sum(axis=(-2, -1))
    smoothed = counts > 4
    if invert:
        smoothed = ~smoothed

    smoothed = smoothed | m[1:-1, 1:-1]

    out = m.copy()
    out[1:-1, 1:-1] = smoothed

    return out


def unitrange(vmin, vmax, N=10):
    """Return values from *vmin* in "round" steps until *vmax* is reached.

    The step is ``(vmax - vmin) / N`` rounded to a multiple of a power of ten,
    so that it has one or two significant digits. Values are
    ``vmin + k * step`` for ``k = 0 .. ceil((vmax - vmin) / step)``. The last
    value is therefore the first one that is at or beyond *vmax*. Ratios within
    1e-9 of an integer are treated as that integer, so floating-point noise
    cannot add a spurious extra step.

    Integer inputs with a step >= 1 yield an integer array.

    Raises
    ------
    ValueError
        If ``N`` is not positive, or if ``vmax - vmin`` is not a finite,
        strictly positive number.
    """
    if not N > 0:
        raise ValueError("N must be positive.")
    diff = float(vmax - vmin)
    # Reject empty, reversed and non-finite ranges here; otherwise log10 below
    # yields -inf/NaN and fails later with an obscure error.
    if not (np.isfinite(diff) and diff > 0):
        raise ValueError("vmax must be greater than vmin and both must be finite.")
    raw_step = diff / N
    unit = 10**round(np.log10(raw_step) - 1)
    step = round(raw_step / unit) * unit
    # np.arange(vmin, vmax + step, step) can overshoot by a whole extra step on
    # float noise (e.g. 3.0000000000000004 -> 4 points). Counting the steps
    # explicitly with a tiny tolerance avoids that.
    nsteps = int(np.ceil(round(diff / step, 9)))
    return vmin + step * np.arange(nsteps + 1)


if __name__ == '__main__':
    # Manual smoke check. The module uses a relative import
    # (``from . import interp``), so run it as ``python -m <package>.<module>``
    # rather than as a plain script.
    demo_range = unitrange(0.4, 0.6, N=4)
    print(demo_range)
