from pathlib import Path

import lmfit
import matplotlib.pyplot as plt
import numpy as np
from quantify_core.data.handling import (
    load_dataset_from_path,
)

def calc_diamond_modes(x, lambda_0=0.637, n_d=2.41):
    """Return the diamond thickness of the ``x``-th diamond mode.

    Evaluates ``(2 * x + 1) * lambda_0 / (4 * n_d)``, i.e. odd multiples of a
    quarter wavelength inside the diamond.  The result has the length unit of
    ``lambda_0`` (µm with the defaults used throughout this script).  ``x``
    may be a scalar or a NumPy array (evaluated element-wise).
    """
    return (2 * x + 1) * (lambda_0 / (4 * n_d))


def calc_air_modes(x, lambda_0=0.637, n_d=2.41):
    """Return the diamond thickness of the ``x``-th air mode.

    Evaluates ``2 * x * lambda_0 / (4 * n_d)``, i.e. even multiples of a
    quarter wavelength inside the diamond, in the length unit of ``lambda_0``.
    ``x`` may be a scalar or a NumPy array (evaluated element-wise).
    """
    return (2 * x) * (lambda_0 / (4 * n_d))


def find_nearest(array, value):
    """Return the index of the element of ``array`` closest to ``value``.

    Ties resolve to the first matching position (``argmin`` semantics); for
    multi-dimensional input the index refers to the flattened array.
    """
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()


def l_abs(t_d, alpha_diamond):
    """Absorption loss of a diamond slab of thickness ``t_d``.

    Returns ``2 * alpha_diamond * t_d``; the factor 2 presumably accounts for
    a double pass through the slab (TODO confirm).  ``alpha_diamond`` must be
    in inverse units of ``t_d``.
    """
    double_pass_coefficient = 2 * alpha_diamond
    return double_pass_coefficient * t_d


def l_scat(t_d, n_d, lambda_0, sigma_rms):
    """Scattering loss term for a diamond slab with rms roughness ``sigma_rms``.

    Evaluates
        sin^2(2*pi*n_d*t_d/lambda_0) * (1 + n_d) * (n_d - 1)^2 / n_d
        * (4*pi*sigma_rms/lambda_0)^2

    ``t_d``, ``lambda_0`` and ``sigma_rms`` must share one length unit (the
    script uses µm).  Works element-wise on NumPy arrays.
    """
    phase = (2 * np.pi * n_d * t_d) / lambda_0
    index_factor = (1 + n_d) * (n_d - 1) ** 2
    roughness_factor = ((4 * np.pi * sigma_rms) / lambda_0) ** 2
    return np.sin(phase) ** 2 * index_factor / n_d * roughness_factor


def alpha(t_d, n_d, lambda_0):
    """Thickness-dependent weighting factor applied to diamond-side losses.

    Evaluates ``sin^2(phi) / n_d + n_d * cos^2(phi)`` with
    ``phi = 2*pi*n_d*t_d/lambda_0``.  It oscillates between ``1/n_d`` and
    ``n_d``; ``finesse`` divides the mirror, scattering and absorption losses
    by it.  Works element-wise on NumPy arrays.
    """
    phase = (2 * np.pi * n_d * t_d) / lambda_0
    sin_squared = np.sin(phase) ** 2
    cos_squared = np.cos(phase) ** 2
    return 1 / n_d * sin_squared + n_d * cos_squared


def finesse(
    t_d, t_d_offset, l_fiber, l_mirror, l_add, n_d, lambda_0, sigma_rms, alpha_diamond
):
    """Model finesse as a function of diamond thickness.

    Evaluates

        F = 2*pi / (l_fiber + l_mirror/a + l_add + l_scat/a + l_abs/a)

    where ``t = t_d + t_d_offset``, ``a = alpha(t, n_d, lambda_0)``, and
    ``l_scat`` / ``l_abs`` are the helpers of the same name evaluated at ``t``.

    Parameters
    ----------
    t_d : float or ndarray
        Diamond thickness (µm in this script; same length unit as
        ``lambda_0`` and ``sigma_rms``).  Arrays are evaluated element-wise.
    t_d_offset : float
        Offset added to ``t_d`` before evaluating the model.
    l_fiber, l_mirror, l_add : float
        Loss contributions (fractions; this script uses values of order 1e-4).
        ``l_mirror`` is divided by ``alpha``; the other two are not.
    n_d : float
        Refractive index of the diamond.
    lambda_0 : float
        Wavelength.
    sigma_rms : float
        Roughness passed to ``l_scat``.
    alpha_diamond : float
        Absorption coefficient passed to ``l_abs``.

    Notes
    -----
    ``lmfit.Model(finesse)`` builds its parameters from this signature, so
    the parameter names and order must not change.
    """
    thickness = t_d + t_d_offset
    # All three diamond-side losses share the same 1/alpha weight; evaluate it
    # once instead of three times (same operations, identical results).
    inv_alpha = 1 / alpha(thickness, n_d, lambda_0)
    total_loss = (
        l_fiber
        + inv_alpha * l_mirror
        + l_add
        + inv_alpha * l_scat(thickness, n_d, lambda_0, sigma_rms)
        + inv_alpha * l_abs(thickness, alpha_diamond)
    )
    return 2 * np.pi / total_loss


# Tick-label colours marking the two mode families on the x axes.
diamond_mode_color = "blueviolet"
air_mode_color = "deepskyblue"

# Diamond thicknesses in µm, read off the white-light wavy pattern
# (supplementary Fig. 11).
t_d_measured_pai_mei = 3.309
t_d_measured_vincent_vega = 2.508

# Low/high additional loss (fraction) and rms roughness (µm) bounding the
# model band drawn around each fit and used for the in-band statistics.
add_loss_low_vincent_vega, add_loss_high_vincent_vega = 420e-6, 690e-6
rms_low_vincent_vega, rms_high_vincent_vega = 0.0006, 0.0012

add_loss_low_pai_mei, add_loss_high_pai_mei = 500e-6, 1040e-6
rms_low_pai_mei, rms_high_pai_mei = 0.0011, 0.0018

# --- Vincent Vega scan -------------------------------------------------------
data_vincent_vega = load_dataset_from_path(
    Path(
        "../data/20241218/20241218-043859-573-85575c-cavity_lateral_finesse_scan/dataset_processed.hdf5"
    )
)
# Re-reference the scan's thickness axis to the white-light measurement.
offset_vincent_vega = t_d_measured_vincent_vega - data_vincent_vega.attrs[
    "t_d_offset_wavy_pattern"
]
thickness_vincent_vega = (
    data_vincent_vega.diamond_thickness + offset_vincent_vega
).values
finesse_vincent_vega = data_vincent_vega.finesse

# --- Pai-Mei scan ------------------------------------------------------------
data_pai_mei = load_dataset_from_path(
    Path(
        "../data/20250101/20250101-231412-630-520c55-cavity_lateral_finesse_scan/dataset_processed.hdf5"
    )
)
offset_pai_mei = t_d_measured_pai_mei - data_pai_mei.attrs["t_d_offset_wavy_pattern"]
thickness_pai_mei = data_pai_mei.diamond_thickness.values + offset_pai_mei
finesse_pai_mei = data_pai_mei.finesse
# Keep only the part of the scan with 3.36 µm < t_d < 3.89 µm (bounds exclusive).
mask = np.logical_and(thickness_pai_mei > 3.36, thickness_pai_mei < 3.89)
thickness_pai_mei = thickness_pai_mei[mask]
finesse_pai_mei = finesse_pai_mei[mask]
def _make_finesse_model():
    """Return an ``lmfit.Model`` of ``finesse`` with this script's parameter hints.

    Free parameters: ``t_d_offset`` (bounded to ±0.066 µm), ``l_add`` and
    ``sigma_rms`` (bounded to 0.0001–0.01 µm).  All other parameters are held
    fixed.  Both samples use this one definition so their fits cannot drift
    apart.
    """
    model = lmfit.Model(finesse)
    model.set_param_hint("t_d_offset", value=0.0, vary=True, min=-0.066, max=0.066)
    model.set_param_hint("l_fiber", value=50e-6, vary=False)
    model.set_param_hint("l_mirror", value=670e-6, vary=False)
    model.set_param_hint("l_add", value=550e-6)
    model.set_param_hint("n_d", value=2.41, vary=False)
    model.set_param_hint("lambda_0", value=0.637, vary=False)
    model.set_param_hint("sigma_rms", value=0.001, min=0.0001, max=0.01)
    model.set_param_hint("alpha_diamond", value=0, vary=False)
    return model


print("Vincent Vega offset interferometer - whitelight: {}".format(offset_vincent_vega))

fit_finesse_vincent_vega = _make_finesse_model()
fit_result_vincent_vega = fit_finesse_vincent_vega.fit(
    finesse_vincent_vega,
    t_d=thickness_vincent_vega,
    params=fit_finesse_vincent_vega.make_params(),
)

print("Fit Report Vincent Vega:")
print(fit_result_vincent_vega.fit_report())


print("Pai-Mei offset interferometer - whitelight: {}".format(offset_pai_mei))

fit_finesse_pai_mei = _make_finesse_model()
fit_result_pai_mei = fit_finesse_pai_mei.fit(
    finesse_pai_mei, t_d=thickness_pai_mei, params=fit_finesse_pai_mei.make_params()
)

print("Fit Report Pai Mei:")
print(fit_result_pai_mei.fit_report())

plot_td_vincent_vega = np.linspace(
    np.min(thickness_vincent_vega), np.max(thickness_vincent_vega), 1000
)
plot_td_pai_mei = np.linspace(
    np.min(thickness_pai_mei), np.max(thickness_pai_mei), 1000
)

plot_vincent_vega_td_offset = fit_result_vincent_vega.params["t_d_offset"].value
plot_pai_mei_td_offset = fit_result_pai_mei.params["t_d_offset"].value


def _loss_label(l_add, sigma_rms):
    """Return the legend text for a curve with loss ``l_add`` and roughness ``sigma_rms``.

    ``l_add`` is shown in ppm rounded to the nearest 10; ``sigma_rms`` (µm) is
    shown in nm with one decimal.
    """
    return (
        r"$\mathcal{L}_{\rm add}$="
        + "{:.0f} ppm,".format(round(1e6 * l_add, -1))
        + r" $\sigma_{DA}$="
        + "{:.1f} nm".format(1e3 * sigma_rms)
    )


def _plot_finesse_panel(
    ax,
    panel_label,
    thickness,
    measured_finesse,
    plot_td,
    td_offset,
    fit_result,
    add_loss_low,
    add_loss_high,
    rms_low,
    rms_high,
):
    """Draw one finesse-vs-thickness panel on ``ax``.

    Scatters the measured finesse, overlays the best-fit curve (black) and two
    model curves for the low/high (l_add, sigma_rms) pairs, then sets the
    shared y-range, y-label and legend.  Data and curves are shifted by
    ``td_offset`` along x.
    """
    viridis = plt.colormaps.get_cmap("viridis")
    ax.text(-0.13, 1.05, panel_label, transform=ax.transAxes, fontsize=12, va="top")
    ax.scatter(
        thickness + td_offset,
        measured_finesse,
        s=0.05,
        zorder=2,
        color=viridis(60),
    )
    ax.plot(
        plot_td + td_offset,
        finesse(plot_td, **fit_result.best_values),
        color="black",
        zorder=4,
        label=_loss_label(
            fit_result.params["l_add"].value, fit_result.params["sigma_rms"].value
        ),
    )
    # Band edges: fitted offset, the same fixed values as the fit's parameter
    # hints, and (l_add, sigma_rms) replaced by each edge.
    # NOTE: viridis has 256 entries, so index 280 is out of range and maps to
    # the colormap's "over" colour (by default its last entry).
    for add_loss, rms, color_index in (
        (add_loss_low, rms_low, 280),
        (add_loss_high, rms_high, 240),
    ):
        ax.plot(
            plot_td + td_offset,
            finesse(plot_td, td_offset, 50e-6, 670e-6, add_loss, 2.41, 0.637, rms, 0.0),
            color=viridis(color_index),
            label=_loss_label(add_loss, rms),
        )
    ax.set_ylim(0, 10000)
    ax.set_ylabel(r"Finesse $\mathcal{F}$")
    ax.legend(loc="lower right", fontsize=8)


def _set_mode_ticks(ax, positions):
    """Place x ticks at ``positions`` (rounded to 3 decimals) and colour their labels.

    ``positions`` must alternate diamond, air, diamond, ...: labels at even
    positions get ``diamond_mode_color`` and those at odd positions get
    ``air_mode_color``.
    """
    ax.set_xticks([round(t, 3) for t in positions])
    for index, tick_label in enumerate(ax.get_xticklabels()):
        tick_label.set_color(diamond_mode_color if index % 2 == 0 else air_mode_color)


fig, (ax_1, ax_2) = plt.subplots(2, 1, figsize=(8, 6), gridspec_kw={"hspace": 0.15})

# (a) Vincent Vega
_plot_finesse_panel(
    ax_1,
    "(a)",
    thickness=thickness_vincent_vega,
    measured_finesse=finesse_vincent_vega,
    plot_td=plot_td_vincent_vega,
    td_offset=plot_vincent_vega_td_offset,
    fit_result=fit_result_vincent_vega,
    add_loss_low=add_loss_low_vincent_vega,
    add_loss_high=add_loss_high_vincent_vega,
    rms_low=rms_low_vincent_vega,
    rms_high=rms_high_vincent_vega,
)
ax_1.set_xlim(2.4, 2.75)

ticks = [
    calc_diamond_modes(18),
    calc_air_modes(19),
    calc_diamond_modes(19),
    calc_air_modes(20),
    calc_diamond_modes(20),
]
_set_mode_ticks(ax_1, ticks)

# (b) Pai-Mei
_plot_finesse_panel(
    ax_2,
    "(b)",
    thickness=thickness_pai_mei,
    measured_finesse=finesse_pai_mei,
    plot_td=plot_td_pai_mei,
    td_offset=plot_pai_mei_td_offset,
    fit_result=fit_result_pai_mei,
    add_loss_low=add_loss_low_pai_mei,
    add_loss_high=add_loss_high_pai_mei,
    rms_low=rms_low_pai_mei,
    rms_high=rms_high_pai_mei,
)
ax_2.set_xlim(3.36, 3.89)
ax_2.set_xlabel(r"Diamond thickness $t_d$" + " (µm)")

ticks_2 = [
    calc_diamond_modes(25),
    calc_air_modes(26),
    calc_diamond_modes(26),
    calc_air_modes(27),
    calc_diamond_modes(27),
    calc_air_modes(28),
    calc_diamond_modes(28),
    calc_air_modes(29),
]
_set_mode_ticks(ax_2, ticks_2)

def _count_within_loss_band(
    thickness,
    measured_finesse,
    td_offset,
    add_loss_low,
    add_loss_high,
    rms_low,
    rms_high,
):
    """Return how many measured points lie inside the model loss band.

    The band at each measured thickness runs from the high-loss model curve
    (``add_loss_high``, ``rms_high``) up to the low-loss curve (``add_loss_low``,
    ``rms_low``), both evaluated with ``td_offset`` and the same fixed
    parameters used for the plotted curves.  Both edges are inclusive; NaN
    measurements compare False and are not counted.
    """
    thickness = np.asarray(thickness)
    measured = np.asarray(measured_finesse)
    # Vectorized: finesse() is element-wise, so one call replaces the
    # per-point Python loop.
    f_high = finesse(
        thickness, td_offset, 50e-6, 670e-6, add_loss_high, 2.41, 0.637, rms_high, 0.0
    )
    f_low = finesse(
        thickness, td_offset, 50e-6, 670e-6, add_loss_low, 2.41, 0.637, rms_low, 0.0
    )
    inside = (measured >= f_high) & (measured <= f_low)
    return int(np.count_nonzero(inside))


count_vincent_vega = _count_within_loss_band(
    thickness_vincent_vega,
    finesse_vincent_vega,
    fit_result_vincent_vega.params["t_d_offset"].value,
    add_loss_low_vincent_vega,
    add_loss_high_vincent_vega,
    rms_low_vincent_vega,
    rms_high_vincent_vega,
)
# Printed as the fraction of all points that fall inside the band.
print("Count Vincent Vega: {}".format(count_vincent_vega / len(thickness_vincent_vega)))

count_pai_mei = _count_within_loss_band(
    thickness_pai_mei,
    finesse_pai_mei,
    fit_result_pai_mei.params["t_d_offset"].value,
    add_loss_low_pai_mei,
    add_loss_high_pai_mei,
    rms_low_pai_mei,
    rms_high_pai_mei,
)
print("Count Pai-Mei: {}".format(count_pai_mei / len(thickness_pai_mei)))

def _max_scattering_loss_ppm(thickness, sigma_rms):
    """Return the maximum of ``l_scat`` over ``thickness`` in ppm.

    Uses n_d = 2.41 and lambda_0 = 0.637 µm, matching the rest of the script.
    """
    return np.max(1e6 * l_scat(thickness, 2.41, 0.637, sigma_rms))


# Both maxima are taken over the dense 1000-point plotting grid (shifted by the
# fitted offset), so the two samples are evaluated the same way.  Previously,
# Vincent Vega used only the measured thickness points.
print(
    "Maximal scattering losses Vincent Vega (0.6 nm): {:.0f} ppm".format(
        _max_scattering_loss_ppm(
            plot_td_vincent_vega + plot_vincent_vega_td_offset, 0.6e-3
        )
    )
)
print(
    "Maximal scattering losses Pai Mei (1.1 nm): {:.0f} ppm".format(
        _max_scattering_loss_ppm(plot_td_pai_mei + plot_pai_mei_td_offset, 1.1e-3)
    )
)

fig.savefig(Path("../Fig_4.png"), dpi=600, bbox_inches="tight")
