import numpy as np

from scipy import constants
import qcodes as qc
configuration = qc.config
from xarray import DataArray, merge
from resonator import background, shunt

from qcodes.dataset.data_set import load_by_guid

from uncertainties import ufloat
import uncertainties.umath as umath
from uncertainties.umath import *
from uncertainties import unumpy

import model_mwimpedance as mwi
import model_cpwresonator as cpwr

def db2xarray_guid(guid, **kwargs):
    """
    Load a QCoDeS dataset identified by its GUID and convert it to an xarray.Dataset.

    Thin wrapper: qcodes' load_by_guid fetches the dataset, and the dataset's own
    to_xarray_dataset() performs the conversion. Any extra keyword arguments go
    straight to load_by_guid; presumably this is how a connection to a different
    database is supplied (verify against the installed qcodes version).

    @param guid: GUID string identifying the dataset in the QCoDeS database
    @param kwargs: keyword arguments forwarded unchanged to load_by_guid
    @return: the xarray.Dataset returned by to_xarray_dataset(). Per qcodes
             conventions, independent parameters become coordinates and dependent
             parameters become data variables.
    """
    qcodes_dataset = load_by_guid(guid=guid, **kwargs)
    return qcodes_dataset.to_xarray_dataset()


def resonator_fitter(freq, amp, phase, **kwds):
    '''
    Fit one resonator trace with the LinearShuntFitter of the resonator package
    (resonator fitter by Daniel Flanigan).

    The complex data handed to the fitter is built as amp * exp(1j * phase). The
    phase is therefore interpreted in radians, and amp as a linear (not dB)
    magnitude. The background model is always MagnitudeSlopeOffsetPhaseDelay.

    @param freq: array of frequency points
    @param amp: array of linear amplitudes, same length as freq
    @param phase: array of phases in radians, same length as freq
    @param kwds: extra keyword arguments forwarded to shunt.LinearShuntFitter
    @return: tuple (r, fit_error).
             On success, r is the LinearShuntFitter instance and fit_error is False.
             If constructing the fitter raises any Exception, r is 0 and fit_error
             is True. Callers must check fit_error before touching attributes of r.
    '''
    d_complex = amp * np.exp(1j * phase)

    try:
        r = shunt.LinearShuntFitter(frequency=freq,
                                    data=d_complex,
                                    background_model=background.MagnitudeSlopeOffsetPhaseDelay(),
                                    **kwds)
    except Exception:
        # Best-effort: a failed fit is reported via the flag so a long scan is not
        # aborted. Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # still propagate.
        return 0, True

    return r, False


def data_trunc(freq, amp, phase, state=False, background=None,
               half_width=75, fft_low=5, fft_high=240):
    '''
    Truncate a trace to a window around an estimated resonance, to help the
    resonator_fitter function.

    Keyword arguments:
    freq -- array of frequency points
    amp -- array of amplitude data
    phase -- array of phase data
    state -- if falsy, the inputs are returned unchanged. If truthy, the
             resonance is located and all three arrays are cut to
             [cen - half_width, cen + half_width), clipped to the array bounds.
    background -- optional background amplitude trace (same length as amp), for
                  non-flat backgrounds. When given, the resonance is located as
                  the maximum of |band-pass(min(amp - background, 0))|. Otherwise
                  it is simply the index of the minimum of amp.
    half_width -- half window size in samples (default 75)
    fft_low -- rFFT bins [0, fft_low) are zeroed (high-pass); default 5
    fft_high -- rFFT bins [fft_high, end) are zeroed (low-pass); default 240

    Returns: tuple (freq, amp, phase), truncated when state is truthy.
    '''
    if not state:
        return freq, amp, phase

    if background is not None:
        # Keep only dips below the background; bumps above it are clamped to 0.
        amp_corr = np.minimum(amp - background, 0)
        # Band-pass filter in Fourier space to suppress slow background drift
        # (low bins) and sample-to-sample noise (high bins).
        spectrum = np.fft.rfft(amp_corr)
        spectrum[0:fft_low] = 0
        spectrum[fft_high:] = 0
        filtered = np.fft.irfft(spectrum)
        cen = int(np.abs(filtered).argmax())
    else:
        cen = int(np.argmin(amp))

    half_width = int(half_width)
    mi = max(cen - half_width, 0)
    ma = min(cen + half_width, len(freq))
    return freq[mi:ma], amp[mi:ma], phase[mi:ma]


# Number of fit parameters of each resonator-package background model.
_FIT_MODEL_NPARAMS = {'Magnitude': 5, 'MagnitudePhase': 6, 'MagnitudePhaseDelay': 8,
                      'MagnitudeSlopeOffsetPhaseDelay': 9}
# Extra columns appended after the raw fit parameters for every trace.
_FIT_EXTRA_NAMES = ('resonance_frequency_err', 'internal_loss_err', 'coupling_loss_err',
                    'Q_c', 'Q_c_err', 'Q_i', 'Q_i_err')
# resonator_fitter always uses MagnitudeSlopeOffsetPhaseDelay -> 9 + 7 = 16 columns.
_N_FIT_COLUMNS = _FIT_MODEL_NPARAMS['MagnitudeSlopeOffsetPhaseDelay'] + len(_FIT_EXTRA_NAMES)
# Quality cuts: a fit is rejected (stored as NaN) if a relative error reaches these.
_MAX_REL_F_R_ERROR = 0.005
_MAX_REL_Q_I_ERROR = 0.28


def _fit_quality_ok(r):
    '''Return True if fitter r has errors and passes the relative-error quality cuts.'''
    if r.f_r_error is None or r.Q_i_error is None:
        return False
    # Written as "not (reject)" so NaN errors behave as in the original cut.
    return not (r.f_r_error >= r.f_r * _MAX_REL_F_R_ERROR
                or r.Q_i_error >= r.Q_i * _MAX_REL_Q_I_ERROR)


def _fit_trace(freq, amp, phase, resfinder, background, kwds):
    '''
    Truncate and fit one trace.

    Returns (row, names). row holds _N_FIT_COLUMNS values, all NaN if the fit
    raised or failed the quality cuts. names is the list of column names, or
    None when the fitter raised (no result object to read names from).
    '''
    fap = data_trunc(freq, amp, phase, state=resfinder, background=background)
    r, fit_error = resonator_fitter(*fap, **kwds)
    if fit_error:
        return np.full(_N_FIT_COLUMNS, np.nan), None

    names = list(r.result.values.keys()) + list(_FIT_EXTRA_NAMES)
    if not _fit_quality_ok(r):
        return np.full(_N_FIT_COLUMNS, np.nan), names

    row = list(r.result.values.values()) + [
        r.result.params['resonance_frequency'].stderr,
        r.result.params['internal_loss'].stderr,
        r.result.params['coupling_loss'].stderr,
        r.Q_c, r.Q_c_error, r.Q_i, r.Q_i_error]
    return row, names


def _fit_and_merge(d, k, axes, single_ndims, resfinder, background, kwds):
    '''Fit every trace of resonator k over the grid spanned by `axes` and merge the results into d.'''
    # A dataset with exactly single_ndims dims is treated as one resonator
    # (unsuffixed variables); otherwise variables carry the resonator index k.
    suffix = '' if len(d.dims) == single_ndims else str(k)
    freq = getattr(d, 'frequency' + suffix).values
    amp = getattr(d, 'amplitude' + suffix).values
    phase = getattr(d, 'phase' + suffix).values

    if axes:
        grid_shape = tuple(getattr(d, ax).shape[0] for ax in axes)
    else:
        # No control axis: one fit on the whole trace, stored with a length-1 leading axis.
        grid_shape = (1,)
        amp = amp[np.newaxis]
        phase = phase[np.newaxis]

    results = np.zeros(grid_shape + (_N_FIT_COLUMNS,))
    names = None
    for idx in np.ndindex(*grid_shape):
        row, row_names = _fit_trace(freq, amp[idx], phase[idx], resfinder, background, kwds)
        results[idx] = row
        if row_names is not None:
            names = row_names

    if names is None:
        raise RuntimeError('All resonator fits for resonator index {} raised an exception; '
                           'cannot determine fit parameter names.'.format(k))

    coords = {ax: getattr(d, ax).values for ax in axes} or None
    dims = list(axes) or None
    fit_arrays = [DataArray(results[..., ind],
                            name='fit_' + var_name + str(k),
                            coords=coords,
                            dims=dims)
                  for ind, var_name in enumerate(names)]
    return merge([d, *fit_arrays], combine_attrs="drop_conflicts")


def resonator_multiscan_multires(xarray,
                                 fit_axis1=None, fit_axis2=None,
                                 resfinder=False, background=None, samplelabel=None, reslabel=None, **kwds):
    '''
    Fit resonator traces for up to 2-dimensional scans with resonator_fitter and
    return the input dataset with the fit results merged in.

    For each resonator index k, variables named 'fit_<param><k>' are added,
    indexed along the given control axes. Fits that raise or fail the quality
    cuts (relative f_r error >= 0.5 % or relative Q_i error >= 28 %) are stored
    as NaN.

    Keyword arguments:
    xarray -- raw data from database (xarray.Dataset). Its attrs are modified in
              place ("sample", "resonator"). Variables are expected as
              frequency/amplitude/phase for a single resonator, or
              frequency<k>/amplitude<k>/phase<k> otherwise.
    fit_axis1 -- name of control parameter 1 (str), or non-str for no control axis
    fit_axis2 -- name of control parameter 2 (str); used only if fit_axis1 is a str
    resfinder -- boolean for resonator finder in case of non-flat background
    background -- load background trace to help in case of non-flat background
    samplelabel -- sample label, written to attrs["sample"] as "s<samplelabel>"
    reslabel -- resonator label, written to attrs["resonator"]
    kwds -- more keyword arguments for resonator_fitter

    Returns: merged xarray.Dataset.
    Raises: RuntimeError if every fit for some resonator raised an exception.
    '''
    d = xarray
    d.attrs["sample"] = "s" + str(samplelabel)
    d.attrs["resonator"] = reslabel

    # single_ndims: dataset dim count that marks a single-resonator dataset.
    # NOTE(review): the no-axis case uses 2 (as before), not 1 -- confirm intended.
    if not isinstance(fit_axis1, str):
        axes, single_ndims = [], 2
    elif not isinstance(fit_axis2, str):
        axes, single_ndims = [fit_axis1], 2
    else:
        axes, single_ndims = [fit_axis1, fit_axis2], 3

    # Number of resonators is inferred from the dims left after removing control axes
    # (evaluated once, before any merge changes d).
    for k in range(len(d.dims) - len(axes)):
        d = _fit_and_merge(d, k, axes, single_ndims, resfinder, background, kwds)

    return d




def impendance_nw(results, label, qi_ref, axis, resonators):
    '''
    Add nanowire inductance and resistance to each fitted-results dataset.

    The inductance follows Eq. 9 in the Supplement:
        L_nw = imp0 / 4 * (1 / f_nw - 1 / freq0)
    Here f_nw is the fitted resonance frequency, and freq0/imp0 are the measured
    bare resonator frequency and impedance. The resistance is computed by
    model_mwimpedance.res_nw_cpw from the fitted Q_i, the reference Q_i and imp0.

    Keyword arguments:
    results -- list of xarray.Datasets from resonator fitting, indexed 0..len-1.
               Each must contain fit_resonance_frequency<j> and fit_Q_i<j> along
               `axis`. The list is updated in place and also returned.
    label -- list of resonator labels; label[j] pairs with fit variables suffixed j
    qi_ref -- internal quality factors of the reference resonators; qi_ref[j] pairs with label[j]
    axis -- name of the control-parameter coordinate, like gate voltage, temperature, ...
    resonators -- dictionary of resonator properties. resonators["nwr"][label[j]]
                  must provide "freq0_measured" and "imp0_measured" (presumably
                  ufloats when uncertainties are tracked).

    Returns: results with variables nanowire_inductance<j>,
             nanowire_inductance_err<j> and nanowire_resistance<j> merged in.
    '''
    name_list = ['nanowire_inductance', 'nanowire_inductance_err', 'nanowire_resistance']

    for i in range(len(results)):
        for j, res_label in enumerate(label):
            ds = results[i]
            axis_values = ds.coords[axis].values

            # from fit
            freqnw = getattr(ds, 'fit_resonance_frequency' + str(j)).values
            qinw = getattr(ds, 'fit_Q_i' + str(j)).values

            # from resonator parameters
            freq0 = resonators["nwr"][res_label]["freq0_measured"]
            imp0 = resonators["nwr"][res_label]["imp0_measured"]
            qi0 = qi_ref[j]

            ind_nw = imp0 / 4 * (1 / freqnw - 1 / freq0)

            # object dtype: rows may hold uncertainties-derived values
            impnw = np.zeros((3, len(axis_values)), dtype=object)
            impnw[0] = unumpy.nominal_values(ind_nw)
            # presumably zero when the inputs carry no uncertainties
            impnw[1] = unumpy.std_devs(ind_nw)
            impnw[2] = mwi.res_nw_cpw(qinw, qi0, imp0)

            # write to xarray
            new_vars = [DataArray(impnw[ind],
                                  name=var_name + str(j),
                                  coords={axis: axis_values},
                                  dims=[axis])
                        for ind, var_name in enumerate(name_list)]

            results[i] = merge([ds, *new_vars], combine_attrs='drop_conflicts', compat='override')

    return results

def nanowire_inductance(freqs, width, spacing, length, ind_k):
    '''
    Return the nanowire inductance based on Eq. 9 in the Supplement:
        L_nw = Z0 / 4 * (1 / freqs - 1 / f0)
    f0 and Z0 are the bare resonator frequency and impedance computed by
    model_cpwresonator.cpw_bare_resonator from the CPW geometry.

    Keyword arguments:
    freqs -- resonance frequency (or array of frequencies) of the nanowire resonator
    width -- width of central conductor; multiplied by 1e-6 (presumably given in um)
    spacing -- spacing of central conductor to ground; multiplied by 1e-6 (presumably um)
    length -- length of central conductor; multiplied by 1e-6 (presumably um)
    ind_k -- kinetic sheet inductance of NbTiN film; multiplied by 1e-12 (presumably pH/sq)

    The substrate permittivity is fixed at epsilon_r = 11.7 (presumably silicon).
    '''
    micro = 1e-6
    bare_freq, bare_imp = cpwr.cpw_bare_resonator(
        width=width * micro,
        spacing=spacing * micro,
        length=length * micro,
        ind_kin_sq=ind_k * 1e-12,
        epsilon_r=11.7,
    )
    return bare_imp / 4 * (1 / freqs - 1 / bare_freq)

def temp2ind(temp, delta0, s0, length=3):
    '''
    Return the kinetic inductance of a superconductor based on the Mattis-Bardeen
    result (see Tinkham Eq. 3.125):
        L = hbar * length / (pi * s0 * delta0) / tanh(delta0 / (2 k_B T))

    Note: the zero-temperature gap delta0 is used in both places, so the
    temperature dependence of the gap itself is not modelled.

    Keyword arguments:
    temp -- temperature (K, since it is combined with scipy's Boltzmann constant)
    delta0 -- superconducting gap at zero temperature (J, same reason)
    s0 -- normal conductance
    length -- geometric factor multiplying hbar (previously hard-coded to 3);
              presumably length / s0 plays the role of the normal-state
              resistance -- confirm units against the caller
    '''
    return (constants.hbar * length) / (np.pi * s0 * delta0) / np.tanh(delta0 / (2 * constants.k * temp))

def linear(x, m, b):
    '''
    Return the straight line m * x + b evaluated at x.

    Works element-wise when x is a numpy array (suitable as a curve_fit model).
    '''
    slope_term = m * x
    return slope_term + b