import numpy as np
from scipy.signal import correlate2d
from scipy.interpolate import RectBivariateSpline
from .base import WindowTooSmall


def _peak(sx, sy, corr, upsample_factor=1):
    """Return the coordinates of the maximum of a sampled correlation map.

    ``corr`` is laid out with axis 0 sampled at ``sx`` and axis 1 sampled at
    ``sy``, i.e. ``corr.shape == (len(sx), len(sy))``. This is the layout
    ``RectBivariateSpline(sx, sy, corr)`` requires. The same layout is used
    for both the coarse search and the upsampled search, so both agree on
    which index is ``x``.

    Parameters
    ----------
    sx, sy : 1-D array_like
        Sample coordinates along axis 0 and axis 1 of ``corr``. They must be
        strictly increasing when ``upsample_factor > 1`` (spline requirement).
    corr : 2-D array_like
        Sampled correlation values.
    upsample_factor : int, optional
        When greater than 1, refine the integer peak. The refinement evaluates
        a spline of ``corr`` on ``upsample_factor + 2`` points per axis,
        spanning +/-1 coordinate unit around the coarse peak, and takes the
        maximum there. Values <= 1 return the coarse peak.

    Returns
    -------
    shift_x, shift_y : scalar
        Coordinate along axis 0 (from ``sx``) and axis 1 (from ``sy``) of the
        maximum.

    Raises
    ------
    TypeError
        If ``upsample_factor`` is not an integer.
    ValueError
        If ``corr.shape`` is not ``(len(sx), len(sy))``.
    """
    if not isinstance(upsample_factor, (int, np.integer)):
        raise TypeError('upsample factor must be an integer')

    sx = np.asarray(sx)
    sy = np.asarray(sy)
    corr = np.asarray(corr)
    if corr.shape != (sx.size, sy.size):
        raise ValueError(
            'corr must have shape (len(sx), len(sy)) = {}, got {}'.format(
                (sx.size, sy.size), corr.shape))

    # Coarse peak: axis 0 indexes sx, axis 1 indexes sy.
    p_i, p_j = np.unravel_index(np.argmax(corr), corr.shape)
    shift_x, shift_y = sx[p_i], sy[p_j]

    if upsample_factor > 1:
        # NOTE(review): the default cubic spline needs more than 3 samples per
        # axis; smaller maps fail inside scipy. WindowTooSmall is imported by
        # this module but never raised -- confirm whether it belongs here.
        spline = RectBivariateSpline(sx, sy, corr)

        # upsample_factor + 2 points from peak-1 to peak+1 along each axis.
        offsets = np.linspace(-1, 1, upsample_factor + 2)
        fine_sx = offsets + shift_x
        fine_sy = offsets + shift_y

        # Grid evaluation returns shape (fine_sx.size, fine_sy.size). This is
        # the same axis-0 <-> x layout as corr, so the recursive call below
        # interprets it consistently.
        fine_corr = spline(fine_sx, fine_sy)

        # Integer peak of the upsampled map (no further upsampling).
        shift_x, shift_y = _peak(fine_sx, fine_sy, fine_corr,
                                 upsample_factor=1)

    return shift_x, shift_y


def calculate_shift(A, B, upsample_factor=100, **kwargs):
    """Estimate the shift between two equally sized 2-D windows.

    The windows are cross-correlated with ``correlate2d(A, B, mode='same')``.
    The maximum is located (optionally with spline upsampling) on integer
    lag grids ``sx`` (axis 0) and ``sy`` (axis 1).

    Parameters
    ----------
    A, B : 2-D array_like
        Windows to compare; must have identical shapes.
    upsample_factor : int, optional
        Passed to ``_peak``; values greater than 1 enable sub-sample
        refinement of the correlation peak.
    **kwargs
        Accepted and ignored. NOTE(review): presumably kept so this matches
        the call signature of other shift estimators -- confirm with callers.

    Returns
    -------
    shift : ndarray, shape (2,)
        ``[spx, spy]`` as returned by ``_peak`` for the lag grids ``sx`` and
        ``sy``.
    int
        Always ``0``. NOTE(review): presumably a placeholder for an
        uncertainty/quality value -- confirm with callers.

    Raises
    ------
    ValueError
        If ``A`` and ``B`` differ in shape or are not 2-D.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    if A.shape != B.shape:
        raise ValueError(
            'A and B must have the same shape, got {} and {}'.format(
                A.shape, B.shape))
    if A.ndim != 2:
        raise ValueError(
            'A and B must be 2-D, got {} dimensions'.format(A.ndim))

    m, n = A.shape
    # Integer lags, one per correlation sample: for odd sizes
    # -(k-1)/2 .. (k-1)/2; for even sizes -k/2+1 .. k/2.
    # NOTE(review): for even window sizes, confirm the zero-lag sample of
    # correlate2d's 'same' output falls on lag 0; an off-by-one here would
    # bias every shift by one sample.
    sx = np.arange((-m) // 2 + 1, m // 2 + 1)
    sy = np.arange((-n) // 2 + 1, n // 2 + 1)
    xcorr = correlate2d(A, B, mode='same')

    spx, spy = _peak(sx, sy, xcorr, upsample_factor=upsample_factor)

    return np.array([spx, spy]), 0