"""
This script performs the analysis and visualisation for the figures that combine observations and DALES simulations. It is part of the publication:

"Turbulent Exchange of CO2 in the Lower Tropical Troposphere Across Clear-to-Cloudy Conditions"

Author: Vincent S. de Feiter. Contact: vincent.defeiter@wur.nl

Version: 29 April, 2025
"""
#%%
#------------------------------------------------------------------------------------------------------------------------------------------------------
#                                                                   I M P O R T   L I B R A R I E S 
#------------------------------------------------------------------------------------------------------------------------------------------------------
import requests
from bs4 import BeautifulSoup
from datetime import datetime, timedelta
import os
import xarray as xr
import numpy as np
import metpy.calc as mpcalc
from metpy.units import units
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.transforms as transforms
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import AutoMinorLocator
import pandas as pd
import matplotlib.colors as mcolors
import os
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.ticker import MaxNLocator
import matplotlib.image as mpimg
from scipy import stats
from matplotlib import patheffects
import statsmodels.api as sm
from sklearn.metrics import r2_score, mean_squared_error
from matplotlib.patches import Patch
import matplotlib.patheffects as PathEffects
import json

#Maps
import contextily as cx
import pyproj
import matplotlib.image as mpimg
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from scipy.ndimage import rotate
import matplotlib
import matplotlib.patches as patches
from matplotlib.colors import LinearSegmentedColormap, Normalize
import matplotlib.patches as patches
import matplotlib.patheffects as pe
import cartopy
import cartopy.geodesic as cgeo
import cartopy.crs as ccrs

import cartopy.io.img_tiles as cimgt
import io
from urllib.request import urlopen, Request
from PIL import Image
import shapely
from owslib.wmts import WebMapTileService
from cartopy.io.img_tiles import OSM
from scipy.interpolate import griddata
import matplotlib.ticker as mticker


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that writes NumPy arrays as (nested) Python lists.

    Any other object the standard encoder cannot handle is passed on to the
    base implementation, which raises ``TypeError``.
    """

    def default(self, obj):
        if not isinstance(obj, np.ndarray):
            return super().default(obj)
        return obj.tolist()
    
# Global figure style: all text rendered through LaTeX (Computer Modern serif,
# amsmath loaded), one common font size, black 1-pt spines and 1-pt ticks.
fontsize = 15

plt.rcParams.update({
    # LaTeX text rendering
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
    "text.latex.preamble": r"\usepackage{amsmath}",
    # Font sizes (base text, axis labels, legend)
    'font.size': fontsize,
    'axes.labelsize': fontsize,
    'legend.fontsize': fontsize,
    # Padding of titles, axis labels and major tick labels
    'axes.titlepad': 20,
    'axes.labelpad': 20,
    'xtick.major.pad': 5,
    'ytick.major.pad': 5,
    # Plot border (spines)
    'axes.edgecolor': 'k',
    'axes.linewidth': 1,
    # Tick widths
    'xtick.major.width': 1,
    'ytick.major.width': 1,
    'xtick.minor.width': 1,
    'ytick.minor.width': 1,
    # Tick lengths
    'xtick.major.size': 6,
    'ytick.major.size': 6,
    'xtick.minor.size': 4,
    'ytick.minor.size': 4,
})

axiswidth = 0.8


def get_square_bounds(lat_min, lat_max, lon_min, lon_max):
    """Return a square lat/lon box that encloses the given box.

    The latitude range is first widened by 0.01 on each side. The result is
    centred on the (widened) box and has a side length, in degrees, equal to
    the larger of the latitude and longitude ranges.

    Returns
    -------
    tuple
        ``(lat_min, lat_max, lon_min, lon_max)`` of the square box.
    """
    pad = 0.01
    lat_lo, lat_hi = lat_min - pad, lat_max + pad

    mid_lat = (lat_lo + lat_hi) / 2
    mid_lon = (lon_min + lon_max) / 2

    half_side = max(lat_hi - lat_lo, lon_max - lon_min) / 2

    return (mid_lat - half_side, mid_lat + half_side,
            mid_lon - half_side, mid_lon + half_side)



#%%
#------------------------------------------------------------------------------------------------------------------------------------------------------
#                                                           D E F I N E    P A T H S   &   D I R S    
#------------------------------------------------------------------------------------------------------------------------------------------------------

# Root of the local project tree (placeholder; adapt per machine)
root = 'C:/Users/...'

# --- Observations -------------------------------------------------------------
directory_co2_data = f"{root}/CloudRoots/Data/Tower Data..."                  # tower CO2 data
directory_flight = f"{root}/CloudRoots/Data/..."                              # aircraft profiles
directory_flight_detrended = f'{root}/CloudRoots/Data/Aircraft/...'           # detrended aircraft transects
tower_data_rootdir = f"{root}/CloudRoots/Data/..."                            # tower dataframes
boundary_layer_data_root = f'{root}/CloudRoots/Data/RAW/...'                  # boundary-layer data
rootdir_vertical = f'{root}/CloudRoots/Data/...'                              # vertical profiles
rootdir_clouds = f'{root}/CloudRoots/Data/Boundary Layer and Cloud Characteristics/...'   # cloud data
cloud_info = f"{root}/CloudRoots/Data/Boundary Layer and Cloud Characteristics/..."       # cloud information
rootdir_soil = f'{root}/CloudRoots/Data/RAW/Soil/'                            # soil data (trailing '/')

# --- Model output ---------------------------------------------------------------
rootdir_DALES = f'{root}/LES/DALES/CloudRoots/.../'                           # DALES files (trailing '/')

# --- Figures --------------------------------------------------------------------
outdir = f"{root}/Publications/Figures"

#%%
#------------------------------------------------------------------------------------------------------------------------------------------------------
#                                                      S E T T I N G S   &   G E N E R A L  F U N C T I O N S 
#------------------------------------------------------------------------------------------------------------------------------------------------------
zmax = 3500                    #height of the vertical profiles in meters

# Location ATTO
x_ATTO = -59.0055
y_ATTO = -2.1458

# Campaign days, split by cloud regime, and all days together
SCu_list = ['2022-08-09', '2022-08-10', '2022-08-11', '2022-08-15', '2022-08-17', '2022-08-18']
DeepCu_list = ['2022-08-12', '2022-08-13', '2022-08-14', '2022-08-16']
date_list_general = [
    '2022-08-09', '2022-08-10', '2022-08-11', '2022-08-12', '2022-08-13',
    '2022-08-14', '2022-08-15', '2022-08-16', '2022-08-17', '2022-08-18',
]


def _clock_strings(start, freq):
    """Return 'HH:MM:SS' strings from *start* up to 18:00 (inclusive), every *freq*."""
    return pd.date_range(start=start, end="18:00", freq=freq).time.astype(str)


def _decimal_hours(clock_strings):
    """Convert 'HH:MM:SS' strings to decimal hours (seconds are ignored)."""
    hours = []
    for clock in clock_strings:
        parsed = datetime.strptime(clock, '%H:%M:%S')
        hours.append(parsed.hour + parsed.minute / 60)
    return hours


# Reference time axes: '*_s' hold the clock strings, the plain names the
# matching decimal hours.
x_AGS_s = _clock_strings("06:00", '1min')
x_OBS_1_s = _clock_strings("06:00", '1min')
x_OBS_5_s = _clock_strings("06:00", '5min')
x_OBS_30_s = _clock_strings("06:00", '30min')
x_OBS_s_s = _clock_strings("06:00", '180min')
x_OBS_10_s = _clock_strings("06:00", '10min')
x_OBS_15_s = _clock_strings("06:00", '15min')
x_DALES_s = _clock_strings("06:05", '5min')
x_IFS_1_s = _clock_strings("06:00", '1h')
x_IFS_3_s = _clock_strings("06:00", '3h')

x_AGS = _decimal_hours(x_AGS_s)
x_OBS_1 = _decimal_hours(x_OBS_1_s)
x_OBS_5 = _decimal_hours(x_OBS_5_s)
x_OBS_30 = _decimal_hours(x_OBS_30_s)
x_OBS_s = _decimal_hours(x_OBS_s_s)
x_OBS_10 = _decimal_hours(x_OBS_10_s)
x_OBS_15 = _decimal_hours(x_OBS_15_s)
x_DALES = _decimal_hours(x_DALES_s)
x_IFS_1 = _decimal_hours(x_IFS_1_s)
x_IFS_3 = _decimal_hours(x_IFS_3_s)

# The 3-hourly observation axis is the same object as the 3-hourly IFS axis
x_OBS_3 = x_IFS_3
x_OBS_3_s = x_IFS_3_s

def daily_range(date_range, freq, start="06:00", end="18:00"):
    """
    Build a DatetimeIndex holding the clock times start..end (inclusive,
    spaced by *freq*) on each day of *date_range*.

    Parameters
    ----------
    date_range : iterable of pandas.Timestamp
        Days to expand; each is combined with every clock time.
    freq : str
        pandas frequency string for the clock times, e.g. '30min'.
    start, end : str, optional
        First and last clock time of each day. The defaults reproduce the
        original fixed 06:00-18:00 window.

    Returns
    -------
    pandas.DatetimeIndex
        Day-major order: all clock times of the first day, then the next day, ...
    """
    clock_times = pd.date_range(start=start, end=end, freq=freq).time
    return pd.DatetimeIndex(
        [pd.Timestamp.combine(day, clock) for day in date_range for clock in clock_times]
    )


# Campaign days
date_range = pd.date_range(start='2022-08-09', end='2022-08-18', freq='D')

# 30-min stamps between 06:00 and 18:00 on every campaign day. '30min' replaces
# the 'T' minute alias, which pandas deprecated in 2.2 (FutureWarning).
x_OBS_30_daily = daily_range(date_range, '30min')


#Specify heights
ATTO_levels = [43, 100, 127, 151, 172, 196, 223, 247, 298] #dropped 316
INSTANT_levels = [5, 15, 25, 35, 50, 81]
heights_CO2 = [4, 24, 38, 53, 79, 321]

ATTO_levels = np.array(ATTO_levels)
INSTANT_levels = np.array(INSTANT_levels)

# Union of both level sets in ascending order
all_heights = np.sort(np.concatenate((ATTO_levels, INSTANT_levels), axis=0))

#General Functions
def find_closest(arr, val):
    """Return the index of the element of array *arr* nearest to *val*.

    On ties the first such index is returned (``argmin`` semantics).
    """
    distance = np.abs(arr - val)
    return distance.argmin()

#Remove NAN out of the data
def filter_nan(s, o):
    """
    Remove the sample pairs in which the simulated or the observed value is NaN.

    Used by the skill-score functions (e.g. index_agreement), which would
    otherwise produce NaN as output.

    Parameters
    ----------
    s : array-like
        Simulated values.
    o : array-like
        Observed values, with the same number of elements as ``s``.

    Returns
    -------
    s, o : numpy.ndarray
        Flattened 1-D arrays holding only the complete (NaN-free) pairs.
        If every product ``s * o`` is NaN (no usable pair), the inputs are
        returned unfiltered, as arrays, so downstream metrics evaluate to NaN.
    """
    # Accept lists/tuples as well as arrays: ``s * o`` and ``.flatten()``
    # below need ndarray semantics.
    s = np.asarray(s)
    o = np.asarray(o)
    if np.sum(~np.isnan(s * o)) >= 1:
        # One row per (s, o) pair; keep the rows without any NaN
        data = np.transpose(np.array([s.flatten(), o.flatten()]))
        data = data[~np.isnan(data).any(axis=1)]
        s = data[:, 0]
        o = data[:, 1]
    return s, o

#Calculate Index of Agreement
def index_agreement(s, o):
    """
    Index of agreement d (Willmott 1981, 1982).

    d = 1 - sum((o - s)^2) / sum((|s - mean(o)| + |o - mean(o)|)^2)

    Parameters
    ----------
    s : simulated values
    o : observed values

    Returns
    -------
    ia : index of agreement, computed after filter_nan has dropped the
         pairs in which either value is NaN
    """
    s, o = filter_nan(s, o)
    o_bar = np.mean(o)
    squared_error = np.sum((o - s) ** 2)
    potential_error = np.sum((np.abs(s - o_bar) + np.abs(o - o_bar)) ** 2)
    return 1 - squared_error / potential_error

#%%
#------------------------------------------------------------------------------------------------------------------------------------------------------
#                                                   T O W E R   D A T A   P R O C E S S I N G
#------------------------------------------------------------------------------------------------------------------------------------------------------
# Function to process each item with conditions
def process_item(items, upper_limit, lower_limit):
    """Return *items* as a list, with values outside the open interval
    (lower_limit, upper_limit) replaced by NaN.

    The bounds themselves count as outside. NaN inputs fail the comparison
    and therefore also come out as NaN.
    """
    cleaned = []
    for value in items:
        inside = lower_limit < value < upper_limit
        cleaned.append(value if inside else float('nan'))
    return cleaned

# Function to process each date
def process_date(date, df_list, all_heights):
    """
    Extract and range-filter the tower variables of one date at every height.

    Parameters
    ----------
    date : value matched against each frame's 'date' column
    df_list : list of DataFrames, one per measurement height
    all_heights : kept for interface compatibility; not used in the body

    Returns
    -------
    dict
        Variable name -> list with one entry per frame that has rows for
        *date*. 'time' and 'p' entries are the raw ndarrays; every other
        entry is the list returned by process_item (out-of-range -> NaN).
    """
    # (output key, source column, optional transform, upper limit, lower limit)
    plain_fields = (
        ('T', 'Thermo_T', None, 999, -999),
        ('Td', 'Tdew', lambda v: v - 273.15, 999, -999),          # presumably K -> degC
        ('q', 'specific_humidity', lambda v: v * 1000, 999, -999),
        ('TKE', 'TKE', None, 20, -20),
        ('H', 'H', None, 999, -999),
        ('LE', 'LE', None, 999, -999),
        ('wind_speed', 'wind_speed', None, 999, -999),
        ('max_wind_speed', 'max_wind_speed', None, 999, -999),
        ('wind_u', 'wind_u_correct', None, 999, -999),
        ('wind_v', 'wind_v_correct', None, 999, -999),
        ('wind_dir', 'wind_dir', None, 999, -999),
        ('bowen_ratio', 'bowen_ratio', None, 20, -20),
        ('u_star', 'u*', None, 1.2, -999),
        ('u_var', 'u_var', None, 999, -999),
        ('v_var', 'v_var', None, 999, -999),
        ('w_var', 'w_var', None, 999, -999),
        ('density', 'air_density', None, 9999, -9999),
        ('cp', 'air_heat_capacity', None, 9999, -9999),
        ('VPD', 'VPD', lambda v: v / 100, 9999, -99),
        ('L', 'L', lambda v: v / 1000, 99, -99),
        ('z_d_L', '(z-d)/L', None, 9999, -9999),
        ('wind_w', 'w_unrot', None, 9999, -9999),
        ('NEE_flux', 'co2_flux', None, 9999, -9999),
    )

    day_data = {key: [] for key in (
        'time', 'T', 'Td', 'q', 'p', 'theta', 'TKE',
        'H', 'LE', 'wind_speed', 'max_wind_speed',
        'wind_u', 'wind_v', 'wind_dir', 'bowen_ratio',
        'u_star', 'u_var', 'v_var', 'w_var', 'density',
        'cp', 'VPD', 'L', 'z_d_L', 'wind_w', 'NEE_flux',
    )}

    def as_float(rows, name):
        """Column *name* of *rows* as a float ndarray."""
        return rows[name].astype(float).values

    for frame in df_list:
        rows = frame.loc[frame['date'] == date]
        if rows.empty:
            # NOTE(review): skipping shifts the per-height position of every
            # later entry for this date; callers index these lists by height
            # position, so a missing height would misalign them — confirm.
            continue

        day_data['time'].append(rows['time'].values)
        day_data['p'].append(as_float(rows, 'interpolated_pressure'))

        # Potential temperature: (T + 273.15) * (1000e2 / p) ** (287 / cp).
        # Presumably T in degC, p in Pa and cp in J kg-1 K-1 — confirm.
        temperature = as_float(rows, 'Thermo_T')
        pressure = as_float(rows, 'interpolated_pressure')
        heat_capacity = as_float(rows, 'air_heat_capacity')
        theta = [(t + 273.15) * ((1000e2 / p_i) ** (287 / cp_i))
                 for t, p_i, cp_i in zip(temperature, pressure, heat_capacity)]
        day_data['theta'].append(process_item(theta, 999, -999))

        for key, source, transform, upper, lower in plain_fields:
            values = as_float(rows, source)
            if transform is not None:
                values = transform(values)
            day_data[key].append(process_item(values, upper, lower))

    return day_data

# Initialize lists to store complete data
# Container for all processed tower data: variable -> list (one entry per
# date, in date_list_general order) of the per-height series from process_date.
data_complete = {key: [] for key in (
    'time', 'T', 'Td', 'q', 'p', 'theta', 'TKE',
    'H', 'LE', 'wind_speed', 'max_wind_speed', 'wind_u',
    'wind_v', 'wind_dir', 'bowen_ratio', 'u_star', 'u_var',
    'v_var', 'w_var', 'density', 'cp', 'VPD', 'L',
    'z_d_L', 'wind_w', 'NEE_flux',
)}


# Read one CSV per height (file names zero-padded to three digits, e.g.
# DF_005m.csv). Each frame is also bound to a module-level name DF_<hhh>m;
# at module scope globals() is the namespace the original locals() wrote to.
df_name_list = []
df_list = []

for height in all_heights:
    if height < 10:
        stringer = '00' + str(int(height))
    elif height < 100:
        stringer = '0' + str(int(height))
    else:
        stringer = str(int(height))
    df_name = 'DF_' + stringer + 'm'
    globals()[df_name] = pd.read_csv(tower_data_rootdir + '/' + df_name + '.csv')
    df_list.append(globals()[df_name])
    df_name_list.append(df_name)

# Process every campaign date and file the results per variable
for date in date_list_general:
    day_data = process_date(date, df_list, all_heights)
    for key, store in data_complete.items():
        store.append(day_data[key])

#Aggregate

# Initialise dictionaries
# Per variable (except 'time') and per height: regime composites (NaN-aware
# mean/std across days at each time step) and an all-days concatenation.
SCu_tower_mean, SCu_tower_std = {}, {}
DeepCu_tower_mean, DeepCu_tower_std = {}, {}
daily_tower = {}

# Per-day window kept: positions 20..44 (25 samples).
# NOTE(review): presumably 30-min samples from 00:00 UTC, i.e. 10:00-22:00 UTC — confirm.
window = slice(10 * 2, 22 * 2 + 1)

for key, per_date in data_complete.items():
    if key == 'time':
        continue

    SCu_mean, SCu_std = [], []
    DeepCu_mean, DeepCu_std = [], []
    daily = []

    for height in range(len(all_heights)):
        sub_SCu, sub_DeepCu, sub_daily = [], [], []

        # data_complete lists are filled in date_list_general order
        for date_check, heights_that_day in zip(date_list_general, per_date):
            data = np.array(heights_that_day[height]).astype(float)[window]

            if date_check in SCu_list:
                sub_SCu.append(data)
            elif date_check in DeepCu_list:
                sub_DeepCu.append(data)

            sub_daily.extend(data)

        sub_SCu, sub_DeepCu = np.array(sub_SCu), np.array(sub_DeepCu)
        SCu_mean.append(np.nanmean(sub_SCu, axis=0))
        SCu_std.append(np.nanstd(sub_SCu, axis=0))
        DeepCu_mean.append(np.nanmean(sub_DeepCu, axis=0))
        DeepCu_std.append(np.nanstd(sub_DeepCu, axis=0))
        daily.append(np.array(sub_daily))

    SCu_tower_mean[key] = np.array(SCu_mean)
    SCu_tower_std[key] = np.array(SCu_std)
    DeepCu_tower_mean[key] = np.array(DeepCu_mean)
    DeepCu_tower_std[key] = np.array(DeepCu_std)
    daily_tower[key] = np.array(daily)


#%%
#------------------------------------------------------------------------------------------------------------------------------------------------------
#                                                   R A D I A T I O N   P R O C E S S I N G
#------------------------------------------------------------------------------------------------------------------------------------------------------

# Load the radiation table. os.path.join inserts the separator that plain
# concatenation omitted: tower_data_rootdir is defined without a trailing '/'.
df_ATTO_temp = pd.read_csv(os.path.join(tower_data_rootdir, 'RadiationTemp_CloudRoots.csv'))

# Short variable names and the CSV column each one is read from
variables = ['SWin', 'SWout', 'LWin', 'LWout', 'T_LWin', 'T_LWout', 'Qnet', 'LWnet', 'SWnet']
notation = {'SWin': 'SW_in', 'SWout': 'SW_out', 'LWin': 'LW_atm_correct', 'LWout': 'LW_terr_correct', 'T_LWin': 'T_LW_atm', 'T_LWout': 'T_LW_terr', 'Qnet': 'Qnet', 'LWnet': 'LWnet', 'SWnet': 'SWnet'}

# Per variable: one array of daytime values per campaign day (date_list_general order)
day_rad = {key: [] for key in notation}

for date in date_list_general:
    # NOTE(review): string comparison assumes zero-padded 'HH:MM...' text in
    # 'Time_correct'; the upper bound here is '18:10' — confirm intended.
    df_sub = df_ATTO_temp.loc[(df_ATTO_temp['Date'] == date)
                              & (df_ATTO_temp['Time_correct'] >= '06:00')
                              & (df_ATTO_temp['Time_correct'] <= '18:10')]
    for key, column in notation.items():
        day_rad[key].append(df_sub[column].values)


# Aggregate: regime composites (NaN-aware mean/std across days at each time
# step; every day of a regime must have the same number of samples) and an
# all-days concatenation.
SCu_rad_mean = {}
SCu_rad_std = {}
DeepCu_rad_mean = {}
DeepCu_rad_std = {}

daily_rad = {}

for key, per_day in day_rad.items():
    sub_SCu, sub_DeepCu, sub_daily = [], [], []

    for date_check, values in zip(date_list_general, per_day):
        data = np.array(values).astype(float)

        if date_check in SCu_list:
            sub_SCu.append(data)
        elif date_check in DeepCu_list:
            sub_DeepCu.append(data)

        sub_daily.extend(data)

    sub_SCu, sub_DeepCu = np.array(sub_SCu), np.array(sub_DeepCu)

    SCu_rad_mean[key] = np.nanmean(sub_SCu, axis=0)
    SCu_rad_std[key] = np.nanstd(sub_SCu, axis=0)
    DeepCu_rad_mean[key] = np.nanmean(sub_DeepCu, axis=0)
    DeepCu_rad_std[key] = np.nanstd(sub_DeepCu, axis=0)
    daily_rad[key] = np.array(sub_daily)
    

#%%
#------------------------------------------------------------------------------------------------------------------------------------------------------
#                                                   C O 2   P R O C E S S I N G
#------------------------------------------------------------------------------------------------------------------------------------------------------

# Load the tower CO2 table (one column per height, named e.g. '4m')
df = pd.read_csv(f'{directory_co2_data}/CloudRoots_CO2.csv')

# Per height column: one array of 06:00-18:00 values per campaign day
day_CO2 = {f'{height}m': [] for height in heights_CO2}

for date in date_list_general:
    # Series.between is inclusive on both ends, like the >= / <= pair
    daytime = (df['Date'] == date) & df['Time_correct'].between('06:00:00', '18:00:00')
    df_sub = df.loc[daytime]

    for key, per_day in day_CO2.items():
        per_day.append(np.array(df_sub[key].values.flatten()))


# Aggregate: regime composites (NaN-aware mean/std across days at each time
# step) and an all-days concatenation, per height
SCu_CO2_mean, SCu_CO2_std = {}, {}
DeepCu_CO2_mean, DeepCu_CO2_std = {}, {}
daily_CO2 = {}

for key, per_day in day_CO2.items():
    sub_SCu, sub_DeepCu, sub_daily = [], [], []

    # day_CO2 lists are filled in date_list_general order
    for date_check, values in zip(date_list_general, per_day):
        data = np.array(values).astype(float)

        if date_check in SCu_list:
            sub_SCu.append(data)
        elif date_check in DeepCu_list:
            sub_DeepCu.append(data)

        sub_daily.extend(data)

    sub_SCu = np.array(sub_SCu)
    sub_DeepCu = np.array(sub_DeepCu)

    SCu_mean, SCu_std = np.nanmean(sub_SCu, axis=0), np.nanstd(sub_SCu, axis=0)
    DeepCu_mean, DeepCu_std = np.nanmean(sub_DeepCu, axis=0), np.nanstd(sub_DeepCu, axis=0)
    daily = np.array(sub_daily)

    SCu_CO2_mean[key], SCu_CO2_std[key] = SCu_mean, SCu_std
    DeepCu_CO2_mean[key], DeepCu_CO2_std[key] = DeepCu_mean, DeepCu_std
    daily_CO2[key] = daily


#%%

#Aircraft - profiles and transects
def _read_flight_file(path):
    """Read one aircraft file and drop its first data row.

    Every aircraft file is treated this way (presumably a units/description
    row below the header — confirm).
    """
    return pd.read_csv(path).iloc[1:]


def _flight_column(frame, name):
    """Return column *name* of an aircraft frame as a float ndarray."""
    return frame[name].values.astype(float)


# Subtracted from the transect 'alt' values before averaging. Its meaning is
# not stated in this script — TODO confirm.
_TRANSECT_ALT_OFFSET = 120

# Profiles
flight1_all = _read_flight_file(f"{directory_flight}/Aircraft_CloudRoots_flight1_profile.txt")
flight1_all_z = _flight_column(flight1_all, 'alt')
flight1_all_co2 = _flight_column(flight1_all, 'co2')
flight1_all_h2o = _flight_column(flight1_all, 'h2o')

flight2_all = _read_flight_file(f"{directory_flight}/Aircraft_CloudRoots_flight2_profile.txt")
flight2_all_z = _flight_column(flight2_all, 'alt')
flight2_all_co2 = _flight_column(flight2_all, 'co2')
flight2_all_h2o = _flight_column(flight2_all, 'h2o')

# Detrended transects. For each leg the raw series are kept, then *_z is
# overwritten with the scalar mean altitude (minus the offset) and the
# mean/std of CO2 and H2O are computed (NaN-aware).

# Flight 1 - low
flight1_low = _read_flight_file(f"{directory_flight_detrended}/Flight1_low_detrended.csv")
flight1_low_z = _flight_column(flight1_low, 'alt')
flight1_low_co2 = _flight_column(flight1_low, 'co2')
flight1_low_LAT = _flight_column(flight1_low, 'lat')
flight1_low_LON = _flight_column(flight1_low, 'lon')
flight1_low_h2o = _flight_column(flight1_low, 'h2o')

flight1_low_z = np.nanmean(np.array(flight1_low_z) - _TRANSECT_ALT_OFFSET)
flight1_low_mean_co2 = np.nanmean(flight1_low_co2)
flight1_low_std_co2 = np.nanstd(flight1_low_co2)
flight1_low_mean_h2o = np.nanmean(flight1_low_h2o)
flight1_low_std_h2o = np.nanstd(flight1_low_h2o)

# Flight 1 - middle
flight1_middle = _read_flight_file(f"{directory_flight_detrended}/Flight1_middle_detrended.csv")
flight1_middle_z = _flight_column(flight1_middle, 'alt')
flight1_middle_co2 = _flight_column(flight1_middle, 'co2')
flight1_middle_LAT = _flight_column(flight1_middle, 'lat')
flight1_middle_LON = _flight_column(flight1_middle, 'lon')
flight1_middle_h2o = _flight_column(flight1_middle, 'h2o')

flight1_middle_z = np.nanmean(np.array(flight1_middle_z) - _TRANSECT_ALT_OFFSET)
flight1_middle_mean_co2 = np.nanmean(flight1_middle_co2)
flight1_middle_std_co2 = np.nanstd(flight1_middle_co2)
flight1_middle_mean_h2o = np.nanmean(flight1_middle_h2o)
# Backward-compatible alias for the previously misspelled name
fligh1_middle_mean_h2o = flight1_middle_mean_h2o
flight1_middle_std_h2o = np.nanstd(flight1_middle_h2o)

# Flight 1 - high
flight1_high = _read_flight_file(f"{directory_flight_detrended}/Flight1_high_detrended.csv")
flight1_high_z = _flight_column(flight1_high, 'alt')
flight1_high_co2 = _flight_column(flight1_high, 'co2')
flight1_high_LAT = _flight_column(flight1_high, 'lat')
flight1_high_LON = _flight_column(flight1_high, 'lon')
flight1_high_h2o = _flight_column(flight1_high, 'h2o')

flight1_high_z = np.nanmean(np.array(flight1_high_z) - _TRANSECT_ALT_OFFSET)
flight1_high_mean_co2 = np.nanmean(flight1_high_co2)
flight1_high_std_co2 = np.nanstd(flight1_high_co2)
flight1_high_mean_h2o = np.nanmean(flight1_high_h2o)
flight1_high_std_h2o = np.nanstd(flight1_high_h2o)

# Flight 2 - low
flight2_low = _read_flight_file(f"{directory_flight_detrended}/Flight2_low_detrended.csv")
flight2_low_z = _flight_column(flight2_low, 'alt')
flight2_low_co2 = _flight_column(flight2_low, 'co2')
flight2_low_LAT = _flight_column(flight2_low, 'lat')
flight2_low_LON = _flight_column(flight2_low, 'lon')
flight2_low_h2o = _flight_column(flight2_low, 'h2o')

flight2_low_z = np.nanmean(np.array(flight2_low_z) - _TRANSECT_ALT_OFFSET)
flight2_low_mean_co2 = np.nanmean(flight2_low_co2)
flight2_low_std_co2 = np.nanstd(flight2_low_co2)
flight2_low_mean_h2o = np.nanmean(flight2_low_h2o)
flight2_low_std_h2o = np.nanstd(flight2_low_h2o)

# Flight 2 - middle
flight2_middle = _read_flight_file(f"{directory_flight_detrended}/Flight2_middle_detrended.csv")
flight2_middle_z = _flight_column(flight2_middle, 'alt')
flight2_middle_co2 = _flight_column(flight2_middle, 'co2')
flight2_middle_LAT = _flight_column(flight2_middle, 'lat')
flight2_middle_LON = _flight_column(flight2_middle, 'lon')
flight2_middle_h2o = _flight_column(flight2_middle, 'h2o')

flight2_middle_z = np.nanmean(np.array(flight2_middle_z) - _TRANSECT_ALT_OFFSET)
flight2_middle_mean_co2 = np.nanmean(flight2_middle_co2)
flight2_middle_std_co2 = np.nanstd(flight2_middle_co2)
flight2_middle_mean_h2o = np.nanmean(flight2_middle_h2o)
flight2_middle_std_h2o = np.nanstd(flight2_middle_h2o)

# Flight 2 - high
flight2_high = _read_flight_file(f"{directory_flight_detrended}/Flight2_high_detrended.csv")
flight2_high_z = _flight_column(flight2_high, 'alt')
flight2_high_co2 = _flight_column(flight2_high, 'co2')
flight2_high_LAT = _flight_column(flight2_high, 'lat')
flight2_high_LON = _flight_column(flight2_high, 'lon')
flight2_high_h2o = _flight_column(flight2_high, 'h2o')

flight2_high_z = np.nanmean(np.array(flight2_high_z) - _TRANSECT_ALT_OFFSET)
flight2_high_mean_co2 = np.nanmean(flight2_high_co2)
flight2_high_std_co2 = np.nanstd(flight2_high_co2)
flight2_high_mean_h2o = np.nanmean(flight2_high_h2o)
flight2_high_std_h2o = np.nanstd(flight2_high_h2o)

#Respiration
# Soil-respiration chamber export (sheet 'export_1')
resp_data = pd.read_excel(f'{rootdir_soil}Soil respiration ATTO 5-31 august 2022.xlsx', sheet_name='export_1')

# Drop the first two rows (presumably unit/description rows below the header — confirm)
resp_data = resp_data.drop([0, 1])

# Keep only the 'C_Litter' records. .copy() makes this an independent frame,
# so the column assignments below do not raise SettingWithCopyWarning.
resp_data = resp_data.loc[resp_data['LI-8250'] == 'C_Litter'].copy()

# Shift the 'LI-8250 UTC' stamps by -4 h and index on them
# (presumably UTC -> local time — confirm)
resp_data['LI-8250 UTC'] = pd.to_datetime(resp_data['LI-8250 UTC']) - timedelta(hours=4)
resp_data.set_index('LI-8250 UTC', inplace=True)

resp_data = resp_data[['LI-870.1']].copy()
resp_data['LI-870.1'] = resp_data['LI-870.1'].astype(float)

# 1-minute means, gaps filled with DataFrame.interpolate(). '1min'/'30min'
# replace the 'T' minute alias, which pandas deprecated in 2.2.
resp_data = resp_data.resample('1min').mean()
resp_interpolated = resp_data.interpolate()

# Restrict to the campaign period
resp_data_c = resp_interpolated.loc[(resp_interpolated.index >= '2022-08-09 06:00:00') & (resp_interpolated.index <= '2022-08-18 18:00:00')]

# Sample every campaign day at 30-min steps within 06:00-18:00
resp_values = []
for date in date_list_general:
    sel = resp_data_c.loc[resp_data_c.index.strftime('%Y-%m-%d') == date]
    sel = sel.loc[(sel.index.time >= pd.to_datetime('06:00').time()) & (sel.index.time <= pd.to_datetime('18:00').time())]

    times = sel.index

    # 30-min grid anchored at the day's first available minute
    start_time = times.min()
    end_time = times.max()
    time_range = pd.date_range(start=start_time, end=end_time, freq='30min')

    values = sel.loc[sel.index.isin(time_range)]

    resp_values.extend(values['LI-870.1'].values)


# NOTE(review): assumes exactly one value per 30-min stamp of every day, so the
# length matches x_OBS_30_daily — confirm for days with gaps.
Resp = xr.Dataset(
    {
        'Resp': (('time'), resp_values)
    },
    coords={
        'time': x_OBS_30_daily
    })


# Regime composites. The time coordinate is re-encoded as the float HH.MM
# (e.g. 06:30 -> 6.3) so that days can be grouped by time of day.
resp_SCu = Resp.sel(time=Resp['time'].dt.strftime('%Y-%m-%d').isin(SCu_list))
hours = resp_SCu.time.dt.hour
minutes = resp_SCu.time.dt.minute
resp_SCu['time'] = np.array([f'{hour:02d}.{minute:02d}' for hour, minute in zip(hours, minutes)]).astype(float)
SCu_Resp_mean = resp_SCu.groupby('time').mean(dim='time').Resp.values
SCu_Resp_std = resp_SCu.groupby('time').std(dim='time').Resp.values

resp_DeepCu = Resp.sel(time=Resp['time'].dt.strftime('%Y-%m-%d').isin(DeepCu_list))
hours = resp_DeepCu.time.dt.hour
minutes = resp_DeepCu.time.dt.minute
resp_DeepCu['time'] = np.array([f'{hour:02d}.{minute:02d}' for hour, minute in zip(hours, minutes)]).astype(float)
DeepCu_Resp_mean = resp_DeepCu.groupby('time').mean(dim='time').Resp.values
DeepCu_Resp_std = resp_DeepCu.groupby('time').std(dim='time').Resp.values

#%%
#------------------------------------------------------------------------------------------------------------------------------------------------------
#                                                    V E R T I C A L   P R O F I L E S 
#------------------------------------------------------------------------------------------------------------------------------------------------------

# Sounding dates per regime, formatted YYYYMMDD as used in the profile paths
SCu_list_vertical = [
    '20220809', '20220810', '20220811',
    '20220815', '20220817', '20220818',
]
DeepCu_list_vertical = ['20220812', '20220813', '20220814', '20220816']

# Sounding times (HHMM, UTC per the '<time>UTC.csv' file names): 1000 to 2200 every 300
time_list_plot = np.arange(1000, 2300, 300)

def append_values(day_list, category, ref=False):
    """
    Read the sounding-profile CSVs of several days for one cloud regime.

    For every day in *day_list* and every time in the module-level
    ``time_list_plot``, the file
    ``{rootdir_vertical}/{category}[ Reference] Profiles/{day}/{time}UTC.csv``
    is read.

    Returns
    -------
    tuple of 11 lists
        z, p, theta, q, RH, T, Td, wind_u, wind_v, wind_speed, wind_dir.
        Each list holds one entry per day; each entry is a list of column
        arrays, one per time.
    """
    suffix = ' Reference' if ref else ''
    path = f'{rootdir_vertical}/{category}{suffix} Profiles/{{day}}/{{time}}UTC.csv'

    columns = ('z', 'p', 'theta', 'q', 'RH', 'T', 'Td',
               'wind_u', 'wind_v', 'wind_speed', 'wind_dir')
    per_column = {col: [] for col in columns}

    for day in day_list:
        day_columns = {col: [] for col in columns}
        for time in time_list_plot:
            profile = pd.read_csv(path.format(day=day, time=time))
            for col in columns:
                day_columns[col].append(profile[col].values)
        for col in columns:
            per_column[col].append(day_columns[col])

    return tuple(per_column[col] for col in columns)

# Reference profiles per regime
(SCu_values_z_ref, SCu_values_p_ref, SCu_values_theta_ref, SCu_values_q_ref,
 SCu_values_RH_ref, SCu_values_T_ref, SCu_values_Td_ref,
 SCu_values_wind_u_ref, SCu_values_wind_v_ref,
 SCu_values_wind_speed_ref, SCu_values_wind_dir_ref) = append_values(SCu_list_vertical, 'SCu', ref=True)

(DeepCu_values_z_ref, DeepCu_values_p_ref, DeepCu_values_theta_ref, DeepCu_values_q_ref,
 DeepCu_values_RH_ref, DeepCu_values_T_ref, DeepCu_values_Td_ref,
 DeepCu_values_wind_u_ref, DeepCu_values_wind_v_ref,
 DeepCu_values_wind_speed_ref, DeepCu_values_wind_dir_ref) = append_values(DeepCu_list_vertical, 'DeepCu', ref=True)

# Non-reference profiles per regime
(SCu_values_z, SCu_values_p, SCu_values_theta, SCu_values_q,
 SCu_values_RH, SCu_values_T, SCu_values_Td,
 SCu_values_wind_u, SCu_values_wind_v,
 SCu_values_wind_speed, SCu_values_wind_dir) = append_values(SCu_list_vertical, 'SCu')

(DeepCu_values_z, DeepCu_values_p, DeepCu_values_theta, DeepCu_values_q,
 DeepCu_values_RH, DeepCu_values_T, DeepCu_values_Td,
 DeepCu_values_wind_u, DeepCu_values_wind_v,
 DeepCu_values_wind_speed, DeepCu_values_wind_dir) = append_values(DeepCu_list_vertical, 'DeepCu')

#%%
#------------------------------------------------------------------------------------------------------------------------------------------------------
#                                                    B O U N D A R Y   L A Y E R  &  C L O U D S
#------------------------------------------------------------------------------------------------------------------------------------------------------


#------------------------------------------------------------------------------
#       L O A D   O B S E R V A T I O N S -   A B L  &  C L O U D S 
#------------------------------------------------------------------------------

#Sounding variables

# Radiosonde summary table. It is filtered below on its 'Date' and 'Time'
# columns, and its LCL/LFC/EL/cape/cin/h columns are read.
df = pd.read_csv(f"{rootdir_vertical}Timeseries Reference.csv", sep=",")

# Sounding days per regime (YYYYMMDD integers, compared with df['Date'])
SCu_list_s = [20220809, 20220810, 20220811, 20220815, 20220817, 20220818]
DeepCu_list_s = [20220812, 20220813, 20220814, 20220816]

# Launch times as HHMM (compared with df['Time']). Their tick labels are the
# hour shifted back by 4 h -> '6', '9', '12', '15', '18'
# (presumably UTC -> local time; confirm).
time_list = np.array([1000, 1300, 1600, 1900, 2200])
time_list_str = [str(hhmm // 100 - 4) for hhmm in time_list]

# Function to process values for a list of dates
def process_values(date_list, frame=None, times=None):
    """Collect radiosonde diagnostics per date, one slot per launch time.

    Parameters
    ----------
    date_list : iterable
        Dates to select; each is compared with ``frame['Date']``.
    frame : pandas.DataFrame, optional
        Table with 'Date', 'Time', 'LCL', 'LFC', 'EL', 'cape', 'cin' and 'h'
        columns. Defaults to the module-level ``df``.
    times : array-like, optional
        Launch times compared with ``frame['Time']``. Defaults to the
        module-level ``time_list``.

    Returns
    -------
    tuple of six lists
        ``(LCL, LFC, EL, cape, cin, h)``. Each holds one inner list per date,
        containing the values found for every entry of ``times``, in order.
        A time with no matching row contributes NaN. The inner lists therefore
        stay aligned with ``times`` (so they can be averaged with
        ``np.nanmean(..., axis=0)``) regardless of which launches are missing.
        This replaces the former hard-coded end-padding for 20220818.
    """
    if frame is None:
        frame = df
    if times is None:
        times = time_list

    columns = ('LCL', 'LFC', 'EL', 'cape', 'cin', 'h')
    results = {col: [] for col in columns}

    for date in date_list:
        df_sel = frame.loc[frame['Date'] == date]
        per_date = {col: [] for col in columns}

        for time in times:
            df_sub = df_sel.loc[df_sel['Time'] == time]
            for col in columns:
                if len(df_sub):
                    per_date[col].extend(df_sub[col].values)
                else:
                    # Missing launch: keep the slot so positions match `times`.
                    per_date[col].append(float('nan'))

        for col in columns:
            results[col].append(per_date[col])

    return tuple(results[col] for col in columns)

# Process SCu and DeepCu values
SCu_values_LCL, SCu_values_LFC, SCu_values_EL, SCu_values_cape, SCu_values_cin, SCu_values_h = process_values(SCu_list_s)
DeepCu_values_LCL, DeepCu_values_LFC, DeepCu_values_EL, DeepCu_values_cape, DeepCu_values_cin, DeepCu_values_h = process_values(DeepCu_list_s)

#ABL information: per-regime aggregate mean and std series from CSV
df_ABL = pd.read_csv(f'{cloud_info}ABL_SCu_DeepCu_aggregates.csv')
abl_SCu_mean = df_ABL['SCu_mean'].values
abl_SCu_std = df_ABL['SCu_std'].values
abl_DeepCu_mean = df_ABL['DeepCu_mean'].values
abl_DeepCu_std = df_ABL['DeepCu_std'].values

#Cloud information: keep the SCu days, then average over days at each clock time
ds_SCu = xr.open_dataset(f'{cloud_info}SCu_aggregate_CloudRoots_complete.nc')
ds_SCu = ds_SCu.sel(time=ds_SCu['time'].dt.strftime('%Y-%m-%d').isin(SCu_list))
time_df = ds_SCu['time'].to_dataframe()
time_df['hour_minute'] = time_df.index.hour * 100 + time_df.index.minute   # HHMM
ds_SCu = ds_SCu.assign_coords(hour_minute=('time', time_df['hour_minute']))
SCu_mean = ds_SCu.groupby('hour_minute').mean(dim='time')

# Restrict to 06:00-18:00 (HHMM) once and read every field from that selection.
SCu_daytime = SCu_mean.sel(hour_minute=(SCu_mean['hour_minute'] >= 600) & (SCu_mean['hour_minute'] <= 1800))
classification_mean = SCu_daytime['classification'].values
time_MIRA = SCu_daytime['hour_minute'].values

#cloud base / top / thickness (day-averaged)
cloud_base_mean = SCu_daytime['cloud_base_mean'].values
cloud_top_mean = SCu_daytime['cloud_top_mean'].values
cloud_thickness = SCu_daytime['cloud_thickness_mean'].values


#%%
#------------------------------------------------------------------------------------------------------------------------------------------------------
#                                                          D A L E S   P R O C E S S I N G
#------------------------------------------------------------------------------------------------------------------------------------------------------

#Profiles
# DALES mean vertical profiles of experiment 034
profiles = xr.open_dataset(f'{rootdir_DALES}profiles.034.nc')

#Height arrays (zt and zm levels), both offset by +40 m
profiles_zt = profiles['zt'].to_numpy() + 40
profiles_zm = profiles['zm'].to_numpy() + 40

#Variables on the zt levels
profiles_theta = profiles['thl'].to_numpy()
profiles_q = profiles['qt'].to_numpy() * 1000          # scaled x1000 (g/kg if qt is kg/kg - assumption)
profiles_ql = profiles['ql'].to_numpy()
profiles_u = profiles['u'].to_numpy()
profiles_v = profiles['v'].to_numpy()
profiles_CO2 = profiles['sv001'].to_numpy() / 1000     # scaled /1000 (ppb -> ppm, as for wsv001t)
profiles_ac = profiles['cfrac'].to_numpy()

#Variables on the zm levels; downward radiative fluxes are sign-flipped
profiles_wCO2 = profiles['wsv001t'].to_numpy() / 1000
profiles_LWout = profiles['lwu'].to_numpy()
profiles_LWin = -profiles['lwd'].to_numpy()
profiles_SWout = profiles['swu'].to_numpy()
profiles_SWin = -profiles['swd'].to_numpy()
profiles_LWout_clear = profiles['lwuca'].to_numpy()
profiles_LWin_clear = -profiles['lwdca'].to_numpy()
profiles_SWout_clear = profiles['swuca'].to_numpy()
profiles_SWin_clear = -profiles['swdca'].to_numpy()

# Net radiation: incoming minus outgoing, short- plus longwave
profiles_SWnet = profiles_SWin - profiles_SWout
profiles_LWnet = profiles_LWin - profiles_LWout
profiles_Qnet = profiles_SWnet + profiles_LWnet

#------------------------------------------------------------------------------------------------------------------------------------------------------

#tmser: DALES domain time series of experiment 034
tmser = xr.open_dataset(f'{rootdir_DALES}tmser.034.nc')

#Variables
tmser_ac = tmser['cfrac'].to_numpy()
tmser_zb = tmser['zb'].to_numpy()
tmser_zc_av = tmser['zc_av'].to_numpy()
tmser_zc_max = tmser['zc_max'].to_numpy()
tmser_zi = tmser['zi'].to_numpy()
tmser_we = tmser['we'].to_numpy()
tmser_Qnet = tmser['Qnet'].to_numpy()
tmser_H = tmser['H'].to_numpy()
tmser_LE = tmser['LE'].to_numpy()

#------------------------------------------------------------------------------------------------------------------------------------------------------

#wCO2 (surface)
# Parse the whitespace-delimited DALES land-surface time series (tmlsm.034):
#  - line 1: column names, after a leading token that is skipped;
#  - line 2: skipped (presumably units - confirm);
#  - remaining lines: one record each, values kept as strings in `data`.
# The file is closed by the `with` block, and the radiosonde DataFrame `df`
# is no longer overwritten by a file handle.
headers = None
data = {}

with open(f'{rootdir_DALES}tmlsm.034') as tmlsm_file:
    for count, line in enumerate(tmlsm_file):
        if count == 0:
            headers = line.split()[1:]
            data = {item: [] for item in headers}
        elif count == 1:
            continue
        else:
            split = line.split()
            if not split:
                continue  # tolerate blank lines (e.g. a trailing newline)
            if len(split) < len(headers):
                raise ValueError(
                    f'tmlsm.034 line {count + 1}: expected {len(headers)} columns, got {len(split)}'
                )
            for sub, item in enumerate(headers):
                data[item].append(split[sub])


wCO2_s = (np.array(data['wco2']).astype(float))
DALES_Resp = (np.array(data['Resp']).astype(float))
# Scaled by 1000/44.009 (44.009 g/mol = molar mass of CO2; presumably a
# mass-to-molar flux conversion - confirm units)
DALES_Resp_AGS = DALES_Resp*(1000/44.009)

#--------------------------------------------------------------------------------
# DALES boundary-layer height per output time: the profiles.zt height at the
# level where the wthvt profile is minimal. Heights outside 500-2000 m (tested
# before the +40 m offset) are invalid and stored as NaN. idx_1D holds the
# chosen level index (NaN when invalid).
# NOTE(review): if wthvt is stored on the zm levels, indexing profiles.zt with
# its argmin shifts the height by half a level - confirm which grid wthvt uses.
wthvt = profiles.wthvt
time_ref = profiles.time.values

abl_1D = []
idx_1D = []
for i, t_ref in enumerate(time_ref):
    values = wthvt.sel(time=t_ref).values

    if np.all(np.isnan(values)):
        # No usable flux at this time (np.nanargmin would raise).
        abl_val = np.nan
        idx = np.nan
    else:
        # Like np.where(values == values.min())[0][0] this is the first
        # occurrence of the minimum, but NaN levels no longer cause a crash.
        idx = np.nanargmin(values)
        abl_val = profiles.zt[idx].values

        if not 500 <= abl_val <= 2000:
            abl_val = np.nan
            idx = np.nan

    abl_1D.append(abl_val+40)
    idx_1D.append(idx)


#--------------------------------------------------------------------------------

#Initial profiles
# DALES input profiles; the first two lines are skipped. Columns are split on
# any run of whitespace, so the column positions do not depend on the exact
# padding in the file. The original fixed " " and "   " separators broke when
# the spacing varied, e.g. with wider numbers or minus signs.
init_prof = pd.read_csv(f'{rootdir_DALES}prof.inp.034', sep=r"\s+", header=None, skiprows=2)
z_init_prof = init_prof[0].values
theta_init_prof = init_prof[1].values
q_init_prof = init_prof[2].values*1000

init_prof_scalar = pd.read_csv(f'{rootdir_DALES}scalar.inp.034', sep=r"\s+", header=None, skiprows=2)
z_init_prof_scalar = init_prof_scalar[0].values
CO2_init_prof_scalar = init_prof_scalar[1].values/1000


#%%
#------------------------------------------------------------------------------------------------------------------------------------------------------
#                                                             V I S U A L I S A T I O N 
#------------------------------------------------------------------------------------------------------------------------------------------------------

#=================
#    Figure 1
#=================
# 3x3 layout: columns 1 and 3 are map panels (cartopy PlateCarree) for the CO2'
# and H2O' flight tracks; column 2 holds the CO2'-H2O' scatter panels.

fig,ax = plt.subplots(3,3,dpi=72,figsize=(8.27*1.5, 11.69*1.3)) #row, column / width, height

ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8, ax9 = ax.flatten()

# Hide the plain axes in the map slots; projected axes are created on top of
# them with plt.subplot(...) below (the hidden axes stay in the figure).
ax1.axis('off')
# ax2.axis('off')
ax3.axis('off')
ax4.axis('off')
# ax5.axis('off')
ax6.axis('off')
ax7.axis('off')
# ax8.axis('off')
ax9.axis('off')

# These calls act on the current pyplot figure (`fig`), so their order matters.
# NOTE(review): the plain plt.subplot(3, 3, k) calls for ax2/ax5/ax8 either return
# the existing axes in that slot or add a new one, depending on the matplotlib
# version - confirm there is no duplicated axes at those positions.
ax1 = plt.subplot(3, 3, 1, projection=ccrs.PlateCarree())  # Map on ax1
ax2 = plt.subplot(3, 3, 2)  # Standard (non-map) subplot for line plot on ax2
ax3 = plt.subplot(3, 3, 3, projection=ccrs.PlateCarree())  # Map on ax3
ax4 = plt.subplot(3, 3, 4, projection=ccrs.PlateCarree())  # Map on ax4
ax5 = plt.subplot(3, 3, 5)  # Standard (non-map) subplot for line plot on ax5
ax6 = plt.subplot(3, 3, 6, projection=ccrs.PlateCarree())  # Map on ax6
ax7 = plt.subplot(3, 3, 7, projection=ccrs.PlateCarree())  # Map on ax7
ax8 = plt.subplot(3, 3, 8)  # Standard (non-map) subplot for line plot on ax8
ax9 = plt.subplot(3, 3, 9, projection=ccrs.PlateCarree())  # Map on ax9

#Correlation - ax2, ax5, ax8
# CO2'-H2O' scatter for each flight-2 leg ('low' -> ax8, 'middle' -> ax5,
# 'high' -> ax2), with a least-squares line, Pearson r and marginal histograms
# in inset axes.
heights = ['low','middle','high']
axis = [8, 5, 2]
correlation_axes = {8: ax8, 5: ax5, 2: ax2}

for height in heights:

    count = axis[heights.index(height)]
    ax_c = correlation_axes[count]

    #Flight 2. The first data row is dropped; the columns are cast to float
    #afterwards, so it is presumably a units row - verify against the CSV.
    flight2 = pd.read_csv(f"{directory_flight_detrended}/Flight2_{height}_detrended.csv")
    flight2 = flight2.drop(flight2.index[0])

    flight2_h2o = flight2['h2o'].values.astype(float)
    flight2_co2 = flight2['co2'].values.astype(float)

    # Fluctuations about the (NaN-aware) leg mean
    flight2_h2o_mean = np.nanmean(flight2_h2o)
    flight2_co2_mean = np.nanmean(flight2_co2)

    flight2_h2o_prime = flight2_h2o - flight2_h2o_mean
    flight2_co2_prime = flight2_co2 - flight2_co2_mean

    # linregress/pearsonr do not skip NaN samples (the fit and r would become
    # NaN), so the statistics use only pairs where both fluctuations are finite.
    pair_finite = np.isfinite(flight2_co2_prime) & np.isfinite(flight2_h2o_prime)
    co2_pair = flight2_co2_prime[pair_finite]
    h2o_pair = flight2_h2o_prime[pair_finite]

    ax_c.grid(linestyle=':',color='tab:gray',alpha=0.3,linewidth=axiswidth*2)

    # Add horizontal and vertical lines through the origin
    ax_c.axhline(0, color='k',linewidth=1)
    ax_c.axvline(0, color='k',linewidth=1)

    ax_c.scatter(flight2_co2_prime,flight2_h2o_prime,s=0.2,color='tab:blue')

    #Linear regression of H2O' on CO2'
    slope, intercept, r_value, p_value, std_err = stats.linregress(co2_pair, h2o_pair)

    x = np.arange(-2,2.2,0.2)
    ax_c.plot(x, intercept + slope * x, color='tab:red',linewidth=axiswidth*2)

    #Statistics: Pearson r, lower-left ('low'/'middle') or lower-right ('high')
    res = stats.pearsonr(co2_pair, h2o_pair)
    text_x = 0.62 if height == 'high' else 0.07
    text = ax_c.text(text_x, 0.1, f' r = {np.round(res[0],2)}',transform=ax_c.transAxes,fontsize=plt.rcParams['font.size']*1.2)

    #Detailing
    ax_c.set_ylabel(r"$\mathrm{H_2 O}'$ (g kg$^{-1}$)")
    ax_c.set_xlabel(r"$\mathrm{CO_2}'$ (ppm)")

    ax_c.set_xticks(np.arange(-2,3,1),np.arange(-2,3,1))
    ax_c.set_yticks(np.arange(-2,3,1),np.arange(-2,3,1))

    ax_c.yaxis.set_minor_locator(AutoMinorLocator())
    ax_c.xaxis.set_minor_locator(AutoMinorLocator())
    ax_c.spines['right'].set_visible(False)
    ax_c.spines['top'].set_visible(False)

    ax_c.set_xlim(-2,2)
    ax_c.set_ylim(-2,2)

    #Histograms: CO2' above the panel, H2O' to its right (matplotlib's hist
    #ignores NaN values when choosing the bin range)
    ax_histx = ax_c.inset_axes([0, 1.02, 1, 0.25])

    ax_histx.hist(flight2_co2_prime, bins=50, color='tab:blue',alpha=0.5)

    ax_histx.spines['right'].set_visible(False)
    ax_histx.spines['top'].set_visible(False)
    ax_histx.spines['left'].set_visible(False)
    ax_histx.set_xticks(np.arange(-2,2.2,0.2),[])
    ax_histx.set_yticks([],[])
    ax_histx.set_xlim(-2,2)

    ax_histy = ax_c.inset_axes([1.02, 0, 0.25, 1])

    ax_histy.hist(flight2_h2o_prime, bins=50, color='tab:blue',alpha=0.5,orientation='horizontal')

    ax_histy.spines['right'].set_visible(False)
    ax_histy.spines['top'].set_visible(False)
    ax_histy.spines['bottom'].set_visible(False)
    ax_histy.set_xticks([],[])
    ax_histy.set_yticks(np.arange(-2,2.2,0.2),[])
    ax_histy.set_ylim(-2,2)


#Add aircraft tracks - CO2
# Map panels with the flight-2 track coloured by CO2': 'low' -> ax7,
# 'middle' -> ax4, 'high' -> ax1.
axis = [7, 4, 1]
rotations = [-90,-225,-45]  # only referenced by the disabled north-arrow rotation
track_axes = {7: ax7, 4: ax4, 1: ax1}

# Background tiles. A GoogleTiles(style="satellite") tiler used to be created
# and immediately overwritten; OpenStreetMap is what is drawn.
tiler = cartopy.io.img_tiles.OSM()

# One colour normalisation for all track panels. The H2O panels and the shared
# colorbar below reuse it.
norm = Normalize(vmin=-1, vmax=1)

for height in heights:

    count = axis[heights.index(height)]
    ax_map = track_axes[count]
    # Track coordinates live in module-level arrays flight2_<height>_LON / _LAT.
    track_lon = globals()[f'flight2_{height}_LON']
    track_lat = globals()[f'flight2_{height}_LAT']

    flight2 = pd.read_csv(f"{directory_flight_detrended}/Flight2_{height}_detrended.csv")
    flight2 = flight2.drop(flight2.index[0])

    flight2_h2o = flight2['h2o'].values.astype(float)
    flight2_co2 = flight2['co2'].values.astype(float)

    flight2_h2o_mean = np.nanmean(flight2_h2o)
    flight2_co2_mean = np.nanmean(flight2_co2)

    flight2_h2o_prime = flight2_h2o - flight2_h2o_mean
    flight2_co2_prime = flight2_co2 - flight2_co2_mean

    #Add track, then the two site markers (blue: CAMPINA, red: ATTO)
    cb = ax_map.scatter(track_lon, track_lat, c=flight2_co2_prime,cmap='coolwarm', s=10,edgecolors=None,linewidth=0.05,norm=norm)

    ax_map.scatter(-59.0217,-2.1819,marker='^',color='tab:blue',s=150,edgecolors='k',linewidth=1)
    ax_map.scatter(-59.0056354,-2.1458835,marker='^',color='tab:red',s=150,edgecolors='k',linewidth=1)

    #Add map: wide extent for the tiles, then zoom to the fixed window below
    extent = [-60,-58,-4,-2]

    ax_map.set_extent(extent,crs=ccrs.PlateCarree())
    ax_map.add_image(tiler, 10,alpha=0.4)

    ax_map.set_aspect('auto')  # Make map fit the subplot

    #Detailing: fixed lat/lon window around the sites
    new_lat_min = -2.3
    new_lat_max = -2.05
    new_lon_min = -59.1
    new_lon_max = -58.85
    ax_map.set_ylim(new_lat_min, new_lat_max)
    ax_map.set_xlim(new_lon_min, new_lon_max)

    gl = ax_map.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                  linewidth=1.5, color='gray', alpha=0.2, linestyle='--')

    gl.top_labels = False
    gl.right_labels = False

    #Gridline positions every 0.05 degrees
    gl.xlocator = mticker.FixedLocator(np.arange(-60,-57,0.05))
    gl.ylocator = mticker.FixedLocator(np.arange(-4,-1,0.05))


# Shared horizontal colorbar (built from the last track scatter; all share `norm`)
cax = fig.add_axes([0.18, -0.02, 0.7, 0.02]) #x,y,w,h
cbar = fig.colorbar(cb,orientation='horizontal',cax=cax,extend='both')
cbar.ax.set_title(r"$\mathrm{\varphi}'$ (ppm or g kg$^{-1}$)", loc='center', pad=10)
cbar.solids.set(alpha=1)

tick_values = np.round(np.arange(-1, 1.2, 0.2), 1)

# Replace -0.0 with 0.0
tick_labels = [str(tick) if tick != 0 else '0' for tick in tick_values]

cbar.set_ticks(tick_values)
cbar.set_ticklabels(tick_labels)
cbar.minorticks_on()

# Re-plot the site markers with labels on the top-row map (ax1, the last panel
# of the loop) purely to build the figure legend.
ax1.scatter(-59.0217,-2.1819,marker='^',color='tab:blue',s=150,edgecolors='k',linewidth=1,label='CAMPINA')
ax1.scatter(-59.0056354,-2.1458835,marker='^',color='tab:red',s=150,edgecolors='k',linewidth=1, label='ATTO')
ax1.legend(loc='lower left', bbox_to_anchor=(0.45, 1.69), ncol=2,
           borderaxespad=0, frameon=False, markerscale=1.2)

#Add aircraft tracks - H2O
# Map panels with the flight-2 track coloured by H2O': 'low' -> ax9,
# 'middle' -> ax6, 'high' -> ax3. The `norm` and `tiler` of the CO2 panels are reused.
west, south, east, north = -59.05, -2.32, -58.90, -2.06  # not referenced in this section
axis = [9, 6, 3]
rotations = [-90,-225,-45]
track_axes = {9: ax9, 6: ax6, 3: ax3}

for height in heights:

    count = axis[heights.index(height)]
    ax_map = track_axes[count]
    # Track coordinates live in module-level arrays flight2_<height>_LON / _LAT.
    track_lon = globals()[f'flight2_{height}_LON']
    track_lat = globals()[f'flight2_{height}_LAT']

    flight2 = pd.read_csv(f"{directory_flight_detrended}/Flight2_{height}_detrended.csv")
    flight2 = flight2.drop(flight2.index[0])

    flight2_h2o = flight2['h2o'].values.astype(float)
    flight2_co2 = flight2['co2'].values.astype(float)

    flight2_h2o_mean = np.nanmean(flight2_h2o)
    flight2_co2_mean = np.nanmean(flight2_co2)

    flight2_h2o_prime = flight2_h2o - flight2_h2o_mean
    flight2_co2_prime = flight2_co2 - flight2_co2_mean

    #Data: track coloured by H2O', then the two site markers (blue: CAMPINA, red: ATTO)
    cb = ax_map.scatter(track_lon, track_lat, c=flight2_h2o_prime,cmap='coolwarm', s=10,edgecolors=None,linewidth=0.05,norm=norm)
    ax_map.scatter(-59.0217,-2.1819,marker='^',color='tab:blue',s=150,edgecolors='k',linewidth=1)
    ax_map.scatter(-59.0056354,-2.1458835,marker='^',color='tab:red',s=150,edgecolors='k',linewidth=1)

    #Add map: wide extent for the tiles, then zoom to the fixed window below
    extent = [-60,-58,-4,-2]

    ax_map.set_extent(extent,crs=ccrs.PlateCarree())
    ax_map.add_image(tiler, 10,alpha=0.4)

    ax_map.set_aspect('auto')  # Make map fit the subplot

    #Detailing: fixed lat/lon window around the sites
    new_lat_min = -2.3
    new_lat_max = -2.05
    new_lon_min = -59.1
    new_lon_max = -58.85
    ax_map.set_ylim(new_lat_min, new_lat_max)
    ax_map.set_xlim(new_lon_min, new_lon_max)

    gl = ax_map.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                  linewidth=1.5, color='gray', alpha=0.2, linestyle='--')

    # Right-hand column: latitude labels on the right side only
    gl.top_labels = False
    gl.right_labels = True
    gl.left_labels = False

    #Gridline positions every 0.05 degrees
    gl.xlocator = mticker.FixedLocator(np.arange(-60,-57,0.05))
    gl.ylocator = mticker.FixedLocator(np.arange(-4,-1,0.05))

#Add north arrow: image placed in a small, hidden figure-level axes
arrow = fig.add_axes([0.05, 0.89+0.08, 0.1, 0.05]) 
# NOTE(review): placeholder path - point this at the actual Northarrow.png.
north_arrow = mpimg.imread(f'C:/Users/.../Northarrow.png')  
#north_arrow = rotate(north_arrow, rotations[heights.index(height)], reshape=True)
imagebox = OffsetImage(north_arrow, zoom=0.15)  # Adjust zoom to change icon size
arrow.set_axis_off()
arrow.patch.set_facecolor('none')
ab = AnnotationBbox(imagebox, (0.5, 0.5), frameon=False)
arrow.add_artist(ab)

#Scalebar: a black rectangle with a "20 km" label in its own hidden figure-level axes.
# NOTE(review): `longitude_scale` is a length in degrees of longitude, but it is
# drawn in the scale_bar axes' own 0-1 coordinates (0.3 of the figure width). The
# bar length is therefore not tied to the data scale of the map panels, which span
# 0.25 degrees each. `center_lat` is also taken from `extent` (-4..-2), not from
# the plotted -2.3..-2.05 window. Verify the 20 km length visually.
scale_km = 20
scale_bar = fig.add_axes([0.051, 0.85+0.11, 0.3, 0.05]) 
scale_bar.set_axis_off()
scale_bar.patch.set_facecolor('none')
center_lat = (extent[2] + extent[3]) / 2
longitude_scale = scale_km / (111.0 * np.cos(np.radians(center_lat)))  
scale_height = 0.05  
scale_rect = patches.Rectangle((0.3, 0.5), longitude_scale, scale_height, linewidth=1, edgecolor='k', facecolor='k')

scale_bar.add_patch(scale_rect)
scale_bar.text(0.3 + longitude_scale / 2, 0.6 + scale_height * 1.5, s=f'{scale_km} km', horizontalalignment='center')

#Title: one per row, on the middle (scatter) panel
ax2.set_title(r'\textbf{Cloud Layer (z$_H$ $\sim$ 3000 m)}',fontsize=plt.rcParams['font.size']*1.2)
ax5.set_title(r'\textbf{Subcloud Layer (z$_M$ $\sim$ 1100 m)}',fontsize=plt.rcParams['font.size']*1.2)
ax8.set_title(r'\textbf{Roughness Sublayer (z$_L$ $\sim$ 200 m)}',fontsize=plt.rcParams['font.size']*1.2)

#Concentration: left column shows CO2', right column H2O'
ax1.set_title(r"$\mathrm{CO_2}'$",fontsize=plt.rcParams['font.size'])
ax4.set_title(r"$\mathrm{CO_2}'$",fontsize=plt.rcParams['font.size'])
ax7.set_title(r"$\mathrm{CO_2}'$",fontsize=plt.rcParams['font.size'])

ax3.set_title(r"$\mathrm{H_2 O}'$",fontsize=plt.rcParams['font.size'])
ax6.set_title(r"$\mathrm{H_2 O}'$",fontsize=plt.rcParams['font.size'])
ax9.set_title(r"$\mathrm{H_2 O}'$",fontsize=plt.rcParams['font.size'])

#Labels: row labels, (a) bottom row to (c) top row
ax1.text(-0.75,0.55, '(c)',transform=ax1.transAxes,fontsize=plt.rcParams['font.size']*1.2)
ax4.text(-0.75,0.55, '(b)',transform=ax4.transAxes,fontsize=plt.rcParams['font.size']*1.2)
ax7.text(-0.75,0.55, '(a)',transform=ax7.transAxes,fontsize=plt.rcParams['font.size']*1.2)

plt.subplots_adjust(wspace=0.5,hspace=1.2)

# Save in three formats; only the raster JPG needs an explicit dpi
plt.savefig(f'{outdir}/Figure_1.jpg',bbox_inches='tight',dpi=500)
plt.savefig(f'{outdir}/Figure_1.svg',bbox_inches='tight')
plt.savefig(f'{outdir}/Figure_1.pdf',bbox_inches='tight')

plt.show()
plt.close()

#%%
#=================
# Figure 2
#=================

fig,ax = plt.subplots(2,2,dpi=72,figsize=(8.27*1.5, 11.69))  #row, column / width, height

ax1, ax2, ax3, ax4 = ax.flatten()

#=================
# Figure 2a
#=================

statistics = {}

idx = np.where(all_heights == 50)

#OBS
ax1.plot(x_OBS_30,SCu_tower_mean['H'][idx][0],linewidth=axiswidth*2,color='tab:red',ls="",marker='o',markersize=5,markeredgecolor='k',markeredgewidth=0.5,label='OBS H')
ax1.fill_between(x_OBS_30,SCu_tower_mean['H'][idx][0]-SCu_tower_std['H'][idx][0],SCu_tower_mean['H'][idx][0]+SCu_tower_std['H'][idx][0],color='tab:red',alpha=0.2)

ax1.plot(x_OBS_30,SCu_tower_mean['LE'][idx][0],linewidth=axiswidth*2,color='tab:blue',ls="",marker='o',markersize=5,markeredgecolor='k',markeredgewidth=0.5,label='OBS LE')
ax1.fill_between(x_OBS_30,SCu_tower_mean['LE'][idx][0]-SCu_tower_std['LE'][idx][0],SCu_tower_mean['LE'][idx][0]+SCu_tower_std['LE'][idx][0],color='tab:blue',alpha=0.2)

#LES
ax1.plot(x_DALES,tmser_H,linewidth=axiswidth*3,color='tab:red',ls="-",label='DALES H')
ax1.plot(x_DALES,tmser_LE,linewidth=axiswidth*3,color='tab:blue',ls="-",label='DALES LE')

#Statistics

#H
x = SCu_tower_mean['H'][idx][0]
y_old = tmser_H

y = []
for time in x_OBS_30:
    index = find_closest(np.array(x_DALES),time)
    y.append(y_old[index])

nan_mask_x = np.isnan(x)

x = x[~nan_mask_x]
y = np.array(y)[~nan_mask_x]

nan_mask_y = np.isnan(y)

x = x[~nan_mask_y]
y = y[~nan_mask_y]

res = sm.OLS(y,x).fit()
R_squared = res.rsquared

MSE =  np.square(np.subtract(x,y))
MSE_std = np.nanstd(MSE)
MSE_mean = np.nanmean(MSE)

RMSE = np.sqrt(MSE)
RMSE_std = np.nanstd(RMSE)
RMSE_mean = np.nanmean(RMSE)

IOA = index_agreement(x,y)

statistics['H'] = f'R-squared: {R_squared}, MSE: {MSE_mean} ({MSE_std}), RMSE: {RMSE_mean} ({RMSE_std}), IOA: {IOA}'

#After 10 LT
x = SCu_tower_mean['H'][idx][0]
y_old = tmser_H

y = []
for time in x_OBS_30:
    index = find_closest(np.array(x_DALES),time)
    y.append(y_old[index])

idx_10 = find_closest(np.array(x_OBS_30),10)
x = x[idx_10:]
y = y[idx_10:]

nan_mask_x = np.isnan(x)

x = x[~nan_mask_x]
y = np.array(y)[~nan_mask_x]

nan_mask_y = np.isnan(y)

x = x[~nan_mask_y]
y = y[~nan_mask_y]

res = sm.OLS(y,x).fit()
R_squared = res.rsquared

MSE =  np.square(np.subtract(x,y))
MSE_std = np.nanstd(MSE)
MSE_mean = np.nanmean(MSE)

RMSE = np.sqrt(MSE)
RMSE_std = np.nanstd(RMSE)
RMSE_mean = np.nanmean(RMSE)

IOA = index_agreement(x,y)

statistics['H-10LT'] = f'R-squared: {R_squared}, MSE: {MSE_mean} ({MSE_std}), RMSE: {RMSE_mean} ({RMSE_std}), IOA: {IOA}'

#LE
x = SCu_tower_mean['LE'][idx][0]
y_old = tmser_LE

y = []
for time in x_OBS_30:
    index = find_closest(np.array(x_DALES),time)
    y.append(y_old[index])
    
nan_mask_x = np.isnan(x)

x = x[~nan_mask_x]
y = np.array(y)[~nan_mask_x]

nan_mask_y = np.isnan(y)

x = x[~nan_mask_y]
y = y[~nan_mask_y]

res = sm.OLS(y,x).fit()
R_squared = res.rsquared

MSE =  np.square(np.subtract(x,y))
MSE_std = np.nanstd(MSE)
MSE_mean = np.nanmean(MSE)

RMSE = np.sqrt(MSE)
RMSE_std = np.nanstd(RMSE)
RMSE_mean = np.nanmean(RMSE)

IOA = index_agreement(x,y)

statistics['LE'] = f'R-squared: {R_squared}, MSE: {MSE_mean} ({MSE_std}), RMSE: {RMSE_mean} ({RMSE_std}), IOA: {IOA}'

#After 10 LT
x = SCu_tower_mean['LE'][idx][0]
y_old = tmser_LE

y = []
for time in x_OBS_30:
    index = find_closest(np.array(x_DALES),time)
    y.append(y_old[index])

idx_10 = find_closest(np.array(x_OBS_30),10)
x = x[idx_10:]
y = y[idx_10:]

nan_mask_x = np.isnan(x)

x = x[~nan_mask_x]
y = np.array(y)[~nan_mask_x]

nan_mask_y = np.isnan(y)

x = x[~nan_mask_y]
y = y[~nan_mask_y]

res = sm.OLS(y,x).fit()
R_squared = res.rsquared

MSE =  np.square(np.subtract(x,y))
MSE_std = np.nanstd(MSE)
MSE_mean = np.nanmean(MSE)

RMSE = np.sqrt(MSE)
RMSE_std = np.nanstd(RMSE)
RMSE_mean = np.nanmean(RMSE)

IOA = index_agreement(x,y)

statistics['LE-10LT'] = f'R-squared: {R_squared}, MSE: {MSE_mean} ({MSE_std}), RMSE: {RMSE_mean} ({RMSE_std}), IOA: {IOA}'

#detailing
ax1.set_xlabel('LT (hour)')
ax1.set_ylabel("(W m$^{-2}$)")

ax1.tick_params(axis='x')
ax1.tick_params(axis='y')

ax1.yaxis.set_minor_locator(AutoMinorLocator())
ax1.xaxis.set_minor_locator(AutoMinorLocator())
ax1.spines['right'].set_visible(False)
ax1.spines['top'].set_visible(False)

ax1.set_xlim(6,18)
ax1.set_xticks([6,8,10,12,14,16,18],[6,8,10,12,14,16,18])
ax1.set_ylim(0,600)

#Export statistics
json_string = json.dumps(statistics)

# Open a file for writing
with open(f'{outdir}/Figure_2a_statistics.txt', 'w') as f:
    # Write the JSON string to the file
    f.write(json_string)

#=================
# Figure 2b
#=================

statistics = {}

idx = np.where(all_heights == 50)

#OBS
ax2.plot(x_OBS_30,SCu_tower_mean['NEE_flux'][idx][0],linewidth=axiswidth*2,color='green',ls="",marker='o',markersize=5,markeredgecolor='k',markeredgewidth=0.5,label='OBS')
ax2.fill_between(x_OBS_30,SCu_tower_mean['NEE_flux'][idx][0]-SCu_tower_std['NEE_flux'][idx][0],SCu_tower_mean['NEE_flux'][idx][0]+SCu_tower_std['NEE_flux'][idx][0],color='green',alpha=0.2)

#LES
wCO2_s_2 = profiles.wsv001t.values[:,0]/1000 #ppb m/s to ppm m/s
V_standard = (8.314 * (profiles.thl[:,0]))/profiles.presh[:,0] #(8.314 * (25+273.15))/1013.25e2 #m3

wCO2_prof_1D = (1/V_standard)*wCO2_s_2 #concentration (mol/m3) * umol/mol becomes umol/m2 s
wCO2_AGS = wCO2_s*(1000/44.009)

ax2.plot(x_DALES,wCO2_prof_1D,linewidth=axiswidth*3,color='darkgreen',ls="-",label='DALES')
# ax2.plot(x_DALES,wCO2_AGS,linewidth=axiswidth*3,color='tab:green',ls="--")

ax2.axhline(0,color='k',linewidth=axiswidth*2,linestyle='--')


#Statistics

#NEE
x = SCu_tower_mean['NEE_flux'][idx][0]
y_old = wCO2_prof_1D

y = []
for time in x_OBS_30:
    index = find_closest(np.array(x_DALES),time)
    y.append(y_old[index])

nan_mask_x = np.isnan(x)

x = x[~nan_mask_x]
y = np.array(y)[~nan_mask_x]

nan_mask_y = np.isnan(y)

x = x[~nan_mask_y]
y = y[~nan_mask_y]

res = sm.OLS(y,x).fit()

R_squared = res.rsquared

MSE =  np.square(np.subtract(x,y))
MSE_std = np.nanstd(MSE)
MSE_mean = np.nanmean(MSE)

RMSE = np.sqrt(MSE)
RMSE_std = np.nanstd(RMSE)
RMSE_mean = np.nanmean(RMSE)

IOA = index_agreement(x,y)

statistics['NEE'] = f'R-squared: {R_squared}, MSE: {MSE_mean} ({MSE_std}), RMSE: {RMSE_mean} ({RMSE_std}), IOA: {IOA}'

#After 10 LT
x = SCu_tower_mean['NEE_flux'][idx][0]
y_old = wCO2_prof_1D

y = []
for time in x_OBS_30:
    index = find_closest(np.array(x_DALES),time)
    y.append(y_old[index])

idx_10 = find_closest(np.array(x_OBS_30),10)
x = x[idx_10:]
y = y[idx_10:]

nan_mask_x = np.isnan(x)

x = x[~nan_mask_x]
y = np.array(y)[~nan_mask_x]

nan_mask_y = np.isnan(y)

x = x[~nan_mask_y]
y = y[~nan_mask_y]

res = sm.OLS(y,x).fit()

R_squared = res.rsquared

MSE =  np.square(np.subtract(x,y))
MSE_std = np.nanstd(MSE)
MSE_mean = np.nanmean(MSE)

RMSE = np.sqrt(MSE)
RMSE_std = np.nanstd(RMSE)
RMSE_mean = np.nanmean(RMSE)

IOA = index_agreement(x,y)

statistics['NEE-10LT'] = f'R-squared: {R_squared}, MSE: {MSE_mean} ({MSE_std}), RMSE: {RMSE_mean} ({RMSE_std}), IOA: {IOA}'
    

#detailing
ax2.set_xlabel('LT (hour)')
ax2.set_ylabel(r'$\overline{w^{\prime}CO_2^{\prime}}$ ($\mu$mol m$^{-2}$ s$^{-1}$)')

ax2.tick_params(axis='x')
ax2.tick_params(axis='y')
ax2.yaxis.set_minor_locator(AutoMinorLocator())
ax2.xaxis.set_minor_locator(AutoMinorLocator())
ax2.spines['right'].set_visible(False)
ax2.spines['top'].set_visible(False)


ax2.set_xlim(6,18)
ax2.set_xticks([6,8,10,12,14,16,18],[6,8,10,12,14,16,18])
ax2.set_ylim(-30,10)

#export statistics
json_string = json.dumps(statistics)

# Open a file for writing
with open(f'{outdir}/Figure_2b_statistics.txt', 'w') as f:
    # Write the JSON string to the file
    f.write(json_string)


#=================
# Figure 2c
#=================
# Time-height radar cloud mask with observed (ceilometer, radiosonde) and DALES
# boundary-layer height z_i, plus ceilometer-vs-DALES z_i agreement statistics.

legend_long = ['Cirrus','Shallow Cumulus','Altocumulus','Altostratus','Cirrusstratus','Cumulus Congestus','Deep']
legend = ['Ci','SCu','Ac','As','Cs','Cu_con','Cb']
colors = ['tab:blue','tab:blue','tab:blue','tab:blue','tab:blue','tab:blue','tab:blue']
bin = 12  # NOTE(review): shadows the builtin bin(); name kept in case later cells read it
cmap_own = LinearSegmentedColormap.from_list('my_list', colors, N=bin)

statistics = {}


def _agreement_stats(obs, model):
    """Return the OBS-vs-DALES agreement summary for paired samples as a string.

    Pairs in which either value is NaN are dropped first.

    - R-squared: from an OLS fit of ``model`` on ``obs`` with an intercept.
      Without the intercept statsmodels reports the *uncentred* R-squared,
      which is typically close to 1 for any data with a large mean.
    - MSE: mean (std) of the squared errors.
    - RMSE: square root of the mean squared error. (The mean of the
      per-sample square roots is the mean absolute error, reported as MAE.)
    - MAE: mean (std) of the absolute errors.
    - IOA: ``index_agreement(obs, model)``.
    """
    obs = np.asarray(obs, dtype=float)
    model = np.asarray(model, dtype=float)
    valid = ~(np.isnan(obs) | np.isnan(model))
    obs, model = obs[valid], model[valid]
    if obs.size < 2:
        return f'insufficient paired samples (n={obs.size})'

    r_squared = sm.OLS(model, sm.add_constant(obs, has_constant='add')).fit().rsquared

    error = obs - model
    sq_err = np.square(error)
    abs_err = np.abs(error)
    mse_mean, mse_std = np.mean(sq_err), np.std(sq_err)
    rmse = np.sqrt(mse_mean)

    ioa = index_agreement(obs, model)

    return (f'R-squared: {r_squared}, MSE: {mse_mean} ({mse_std}), '
            f'RMSE: {rmse}, MAE: {np.mean(abs_err)} ({np.std(abs_err)}), IOA: {ioa}')


#OBS
levels = [0.1,1,2,3,4,5,6,7]
clouds = classification_mean.T  # rows = heights (SCu_mean.z), columns = times (x_OBS_1)
clouds_cut = clouds[:200,:]
mask = ~np.isnan(clouds_cut)
cb_values = []
ct_values = []

# Per time column, store the height of the last (cb) and first (ct) non-NaN
# gate among the lowest 200 levels; NaN when the column has no such gate.
# NOTE(review): "last = base, first = top" only holds if SCu_mean.z is ordered
# top-down; confirm the ordering of the radar height axis.
z_cut = SCu_mean.z.values[:200]
for col in range(clouds_cut.shape[1]):
    row_indices = np.where(mask[:, col])[0]
    if len(row_indices) > 0:
        cb_values.append(z_cut[row_indices[-1]])
        ct_values.append(z_cut[row_indices[0]])
    else:
        cb_values.append(np.nan)
        ct_values.append(np.nan)


ax3.contourf(x_OBS_1,SCu_mean.z.values,clouds,levels=levels,cmap=cmap_own,alpha=0.8)
# Legend proxy for the radar cloud layer (drawn below the visible y-range)
ax3.fill_between(x_DALES,[-1]*len(x_DALES),[-2]*len(x_DALES),color='tab:blue',alpha=0.8,label='OBS MIRA cloud layer')

ax3.plot(x_OBS_5[:-1],abl_SCu_mean,linewidth=axiswidth*3,color='tab:red',ls="-",marker='',markersize=3,label='OBS Ceilometer z$_i$')
ax3.fill_between(x_OBS_5[:-1],abl_SCu_mean-abl_SCu_std,abl_SCu_mean+abl_SCu_std,color='tab:red',alpha=0.2)

ax3.plot(x_OBS_3,np.nanmean(SCu_values_h,axis=0),linewidth=axiswidth*2,color='tab:red',ls="",marker='^',markersize=8,markeredgecolor='k',markeredgewidth=0.8,label='OBS Radiosonde z$_i$')

#LES
ax3.plot(x_DALES[:122],abl_1D[:122],linewidth=axiswidth*3,color='k',ls="-")

ax3.fill_between(x_DALES[:122],abl_1D[:122],tmser_zc_max[:122],color='tab:purple',alpha=0.3,label='DALES cloud layer')

# Legend proxy for the DALES z_i line (drawn below the visible y-range)
ax3.axhline(-35,color='k',linewidth=axiswidth*3,linestyle='-',label='DALES z$_i$')

#Statistics
# DALES z_i sampled at the DALES time closest to each ceilometer time.
x_DALES_arr = np.array(x_DALES)
abl_DALES_at_obs = np.array([abl_1D[find_closest(x_DALES_arr, time)] for time in x_OBS_5[:-1]])

#ABL (whole day)
statistics['ABL'] = _agreement_stats(abl_SCu_mean, abl_DALES_at_obs)

#After 10LT
idx_10 = find_closest(np.array(x_OBS_5[:-1]),10)
statistics['ABL-10LT'] = _agreement_stats(abl_SCu_mean[idx_10:], abl_DALES_at_obs[idx_10:])


#detailing
ax3.set_xlabel('LT (hour)')
ax3.set_ylabel('z (m)')

ax3.tick_params(axis='x')
ax3.tick_params(axis='y')

ax3.yaxis.set_minor_locator(AutoMinorLocator())
ax3.xaxis.set_minor_locator(AutoMinorLocator())
ax3.spines['right'].set_visible(False)
ax3.spines['top'].set_visible(False)

ax3.set_xlim(5.8,18.2)
ax3.set_xticks([6,8,10,12,14,16,18],[6,8,10,12,14,16,18])
ax3.set_ylim(0,3000)

#export statistics
json_string = json.dumps(statistics)

# Open a file for writing
with open(f'{outdir}/Figure_2c_statistics.txt', 'w') as f:
    # Write the JSON string to the file
    f.write(json_string)
    
#=================
# Figure 2d
#=================
# DALES cloud-cover time series for the three cloud-fraction columns of the CSV.

statistics = {}

#Plot LES
LES_cloud_cover = pd.read_csv("C:/Users/.../Cloud_fraction_timeseries.csv")
x_plot = LES_cloud_cover['LT'].values

# (CSV column, line style, legend label) for each cloud-cover definition
cloud_cover_series = (
    ('Condition 1', '-', 'DALES $a_{c}$'),
    ('Condition 2', '--', 'DALES $a_{c,w}$'),
    ('Condition 3', ':', 'DALES $a_{cc}$'),
)
for column, line_style, series_label in cloud_cover_series:
    ax4.plot(x_plot, LES_cloud_cover[column].values,
             linewidth=axiswidth*3, color='k', ls=line_style, label=series_label)

#Detailing
y_tick_values = np.arange(0, 0.5, 0.1)
ax4.set_yticks(y_tick_values, np.round(y_tick_values, 1))
ax4.set_ylim(0, 0.3)
x_tick_values = np.arange(6, 20, 2)
ax4.set_xticks(x_tick_values, x_tick_values)
ax4.set_xlim(6, 18)
ax4.set_xlabel('LT (hour)')

for side in ('right', 'top'):
    ax4.spines[side].set_visible(False)

for axis_obj in (ax4.yaxis, ax4.xaxis):
    axis_obj.set_minor_locator(AutoMinorLocator())

ax4.set_ylabel('Cloud Cover (-)')

#=======================================================================================================

#Total detailing: panel letters and grids, spacing, legends, then export Figure 2.

panel_letter_size = plt.rcParams['font.size']*1.2
for panel, letter in ((ax1, '(a)'), (ax2, '(b)'), (ax3, '(c)'), (ax4, '(d)')):
    panel.text(-0.1, 1.1, letter, transform=panel.transAxes, fontsize=panel_letter_size)
    panel.grid(linestyle=':', color='tab:gray', alpha=0.3, linewidth=axiswidth*2)

plt.subplots_adjust(wspace=0.4, hspace=0.4)

# (axes, legend anchor, marker scale); all legends share the remaining options
legend_layout = (
    (ax1, (0.05, 1.1), 1.6),
    (ax2, (0.05, 1.15), 2),
    (ax3, (-0.2, -.55), 2),
    (ax4, (0.05, -.50), 2),
)
for panel, anchor, scale in legend_layout:
    panel.legend(loc='lower left', bbox_to_anchor=anchor, ncol=2,
                 borderaxespad=0, frameon=False, markerscale=scale)

#Export (raster at 500 dpi, vector formats without dpi)
for extension, save_kwargs in (('jpg', {'dpi': 500}), ('svg', {}), ('pdf', {})):
    plt.savefig(f'{outdir}/Figure_2.{extension}', bbox_inches='tight', **save_kwargs)
plt.show()


#%%


#=================
# Figure 3 (total)
#=================
# A 2x3 grid of which four panels are kept: three profile panels (a-c) on top
# and one wide time-series panel (d) underneath.
fig, ax = plt.subplots(2, 3, dpi=72, figsize=(8.27, 11.69/1.5))

ax1, ax2, ax3, ax4, ax5, ax6 = ax.flatten()

#Remove the two unused axes
for spare_ax in (ax5, ax6):
    fig.delaxes(spare_ax)

#Manual placement; rectangles are [left, bottom, width, height]
panel_rects = (
    (ax1, [0.05, 0.55, 0.25, 0.4]),
    (ax2, [0.45, 0.55, 0.25, 0.4]),
    (ax3, [0.85, 0.55, 0.25, 0.4]),
    (ax4, [0.05, -0.07, 0.85+0.2, 0.4]),
)
for panel, rect in panel_rects:
    panel.set_position(rect)

#=================
# Figure 3 a
#=================
# Observed (radiosonde mean ± std) vs DALES potential temperature at 12 LT,
# with observed/DALES z_i markers and OBS-vs-DALES agreement statistics.

times_total = [6,9,12,15,18]

sel = times_total.index(12)  # time slot of 12 LT in the *_ref arrays
z = SCu_values_z_ref[0][0]

values = np.array(SCu_values_theta_ref)[:,sel,:]


def _agreement_stats(obs, model):
    """Return the OBS-vs-DALES agreement summary for paired samples as a string.

    Pairs in which either value is NaN are dropped first.

    - R-squared: from an OLS fit of ``model`` on ``obs`` with an intercept.
      Without the intercept statsmodels reports the *uncentred* R-squared,
      which is typically close to 1 for any data with a large mean.
    - MSE: mean (std) of the squared errors.
    - RMSE: square root of the mean squared error. (The mean of the
      per-sample square roots is the mean absolute error, reported as MAE.)
    - MAE: mean (std) of the absolute errors.
    - IOA: ``index_agreement(obs, model)``.
    """
    obs = np.asarray(obs, dtype=float)
    model = np.asarray(model, dtype=float)
    valid = ~(np.isnan(obs) | np.isnan(model))
    obs, model = obs[valid], model[valid]
    if obs.size < 2:
        return f'insufficient paired samples (n={obs.size})'

    r_squared = sm.OLS(model, sm.add_constant(obs, has_constant='add')).fit().rsquared

    error = obs - model
    sq_err = np.square(error)
    abs_err = np.abs(error)
    mse_mean, mse_std = np.mean(sq_err), np.std(sq_err)
    rmse = np.sqrt(mse_mean)

    ioa = index_agreement(obs, model)

    return (f'R-squared: {r_squared}, MSE: {mse_mean} ({mse_std}), '
            f'RMSE: {rmse}, MAE: {np.mean(abs_err)} ({np.std(abs_err)}), IOA: {ioa}')


#OBS
idx = find_closest(np.array(x_OBS_5[:-1]),12)
ax1.axhline(abl_SCu_mean[idx],color='gray',linewidth=axiswidth*1.5,linestyle='--')

# Mean + std of the radar-derived ct_values (Figure 2c) within 11.9-12.1 LT
idx_1 = find_closest(np.array(x_OBS_1),11.90)
idx_2 = find_closest(np.array(x_OBS_1),12.10)

ct_value = np.nanmean(ct_values[idx_1:idx_2+1])+np.nanstd(ct_values[idx_1:idx_2+1])

ax1.axhline(ct_value,color='gray',linewidth=axiswidth*1.5,linestyle='-.')

ax1.plot(np.nanmean(values,axis=0),z,linewidth=axiswidth*2,color='tab:red',ls="",marker='o',markersize=3)
ax1.fill_betweenx(z,np.nanmean(values,axis=0)-np.nanstd(values,axis=0),np.nanmean(values,axis=0)+np.nanstd(values,axis=0),color='tab:red',alpha=0.2)

#DALES
idx = find_closest(np.array(x_DALES),12)
ax1.axhline(abl_1D[idx],color='tab:green',linewidth=axiswidth*1.5,linestyle='--')

# Shaded DALES cloud layer, z_i to tmser_zc_max (as in Figure 2c)
ax1.fill_between(np.arange(300,320,5),abl_1D[idx],tmser_zc_max[idx],color='k',alpha=0.05)

ax1.plot(profiles_theta[idx],profiles_zt,linewidth=axiswidth*5,color='w',ls="-")
ax1.plot(profiles_theta[idx],profiles_zt,linewidth=axiswidth*3,color='tab:green',ls="-",label='DALES')


#detailing
ax1.set_xlabel(r'$\theta$ (K)')
ax1.set_ylabel('z (m)')

ax1.set_ylim(0,3500)
ax1.set_yticks(np.arange(0,4000,500),np.arange(0,4000,500))
ax1.set_xlim(300,315)

ax1.tick_params(axis='x')
ax1.tick_params(axis='y')
ax1.spines['right'].set_visible(False)
ax1.spines['top'].set_visible(False)

ax1.yaxis.set_minor_locator(AutoMinorLocator())
ax1.xaxis.set_minor_locator(AutoMinorLocator())
ax1.set_title('12 LT',y=1.02)

statistics = {}

# Compare up to (excluding) the reference level closest to 3500 m.
cut = find_closest(np.array(SCu_values_z_ref[0][0]),3500)

# DALES theta at the model level closest to each observation height.
theta_dales = profiles_theta[idx]
theta_dales_at_obs = np.array([theta_dales[find_closest(profiles_zt,height)] for height in SCu_values_z_ref[0][0]])

statistics['theta'] = _agreement_stats(np.nanmean(values,axis=0)[:cut], theta_dales_at_obs[:cut])

#export

json_string = json.dumps(statistics)

# Open a file for writing
with open(f'{outdir}/Figure_3a_statistics.txt', 'w') as f:
    # Write the JSON string to the file
    f.write(json_string)



#=================
#   Figure 3 b
#=================
# Observed (radiosonde) vs DALES specific humidity at 12 LT, with the three
# horizontal cross-section legs of flight 2 and agreement statistics.
# Uses `sel` and `z` set up for Figure 3a.

values = np.array(SCu_values_q_ref)[:,sel,:]


def _agreement_stats(obs, model):
    """Return the OBS-vs-DALES agreement summary for paired samples as a string.

    Pairs in which either value is NaN are dropped first.

    - R-squared: from an OLS fit of ``model`` on ``obs`` with an intercept.
      Without the intercept statsmodels reports the *uncentred* R-squared,
      which is typically close to 1 for any data with a large mean.
    - MSE: mean (std) of the squared errors.
    - RMSE: square root of the mean squared error. (The mean of the
      per-sample square roots is the mean absolute error, reported as MAE.)
    - MAE: mean (std) of the absolute errors.
    - IOA: ``index_agreement(obs, model)``.
    """
    obs = np.asarray(obs, dtype=float)
    model = np.asarray(model, dtype=float)
    valid = ~(np.isnan(obs) | np.isnan(model))
    obs, model = obs[valid], model[valid]
    if obs.size < 2:
        return f'insufficient paired samples (n={obs.size})'

    r_squared = sm.OLS(model, sm.add_constant(obs, has_constant='add')).fit().rsquared

    error = obs - model
    sq_err = np.square(error)
    abs_err = np.abs(error)
    mse_mean, mse_std = np.mean(sq_err), np.std(sq_err)
    rmse = np.sqrt(mse_mean)

    ioa = index_agreement(obs, model)

    return (f'R-squared: {r_squared}, MSE: {mse_mean} ({mse_std}), '
            f'RMSE: {rmse}, MAE: {np.mean(abs_err)} ({np.std(abs_err)}), IOA: {ioa}')


#OBS
idx = find_closest(np.array(x_OBS_5[:-1]),12)
# Same line width as the OBS z_i line in panels (a) and (c)
ax2.axhline(abl_SCu_mean[idx],color='gray',linewidth=axiswidth*1.5,linestyle='--')

idx_1 = find_closest(np.array(x_OBS_1),11.90)
idx_2 = find_closest(np.array(x_OBS_1),12.10)

ct_value = np.nanmean(ct_values[idx_1:idx_2+1])+np.nanstd(ct_values[idx_1:idx_2+1])

ax2.axhline(ct_value,color='gray',linewidth=axiswidth*1.5,linestyle='-.')

ax2.plot(np.nanmean(values,axis=0),z,linewidth=axiswidth*2,color='tab:red',ls="",marker='o',markersize=3)
ax2.fill_betweenx(z,np.nanmean(values,axis=0)-np.nanstd(values,axis=0),np.nanmean(values,axis=0)+np.nanstd(values,axis=0),color='tab:red',alpha=0.2)

# Flight-2 horizontal cross-section legs (mean ± std) at fixed plotting heights
for leg_mean, leg_z, leg_std in ((flight2_low_mean_h2o, 200, flight2_low_std_h2o),
                                 (flight2_middle_mean_h2o, 1100, flight2_middle_std_h2o),
                                 (flight2_high_mean_h2o, 3000, flight2_high_std_h2o)):
    (_, caps, _) = ax2.errorbar(leg_mean, leg_z, xerr=leg_std,
                                linewidth=3, color='tab:blue', ls="", marker='^',
                                markersize=10, markeredgecolor='k', markeredgewidth=0.5,
                                capsize=10)
    for cap in caps:
        cap.set_markeredgewidth(3)

#DALES
idx = find_closest(np.array(x_DALES),12)
ax2.axhline(abl_1D[idx],color='tab:green',linewidth=axiswidth*1.5,linestyle='--')

ax2.fill_between(np.arange(0,25,5),abl_1D[idx],tmser_zc_max[idx],color='k',alpha=0.05)

ax2.plot(profiles_q[idx],profiles_zt,linewidth=axiswidth*5,color='w',ls="-")
ax2.plot(profiles_q[idx],profiles_zt,linewidth=axiswidth*3,color='tab:green',ls="-",label='DALES')

#detailing
ax2.set_xlabel('$q$ (g kg$^{-1}$)')

ax2.set_xlim(0,20)
ax2.set_xticks(np.arange(0,25,5),np.arange(0,25,5))
ax2.set_ylim(0,3500)
ax2.set_yticks(np.arange(0,4000,500),np.arange(0,4000,500))

ax2.tick_params(axis='x')
ax2.tick_params(axis='y')
ax2.spines['right'].set_visible(False)
ax2.spines['top'].set_visible(False)

ax2.yaxis.set_minor_locator(AutoMinorLocator())
ax2.xaxis.set_minor_locator(AutoMinorLocator())

ax2.set_title('12 LT',y=1.02)

statistics = {}

# Compare up to (excluding) the reference level closest to 3500 m.
cut = find_closest(np.array(SCu_values_z_ref[0][0]),3500)

# DALES q at the model level closest to each observation height.
q_dales = profiles_q[idx]
q_dales_at_obs = np.array([q_dales[find_closest(profiles_zt,height)] for height in SCu_values_z_ref[0][0]])

statistics['q'] = _agreement_stats(np.nanmean(values,axis=0)[:cut], q_dales_at_obs[:cut])

#export

json_string = json.dumps(statistics)

# Open a file for writing
with open(f'{outdir}/Figure_3b_statistics.txt', 'w') as f:
    # Write the JSON string to the file
    f.write(json_string)

#=================
#   Figure 3 c
#=================
# CO2 profile at 13 LT: flight-2 aircraft samples, the three horizontal
# cross-section legs, tower levels and DALES, plus OBS-vs-DALES statistics.

#CO2

#Flight 2


def _agreement_stats(obs, model):
    """Return the OBS-vs-DALES agreement summary for paired samples as a string.

    Pairs in which either value is NaN are dropped first.

    - R-squared: from an OLS fit of ``model`` on ``obs`` with an intercept.
      Without the intercept statsmodels reports the *uncentred* R-squared,
      which is typically close to 1 for any data with a large mean.
    - MSE: mean (std) of the squared errors.
    - RMSE: square root of the mean squared error. (The mean of the
      per-sample square roots is the mean absolute error, reported as MAE.)
    - MAE: mean (std) of the absolute errors.
    - IOA: ``index_agreement(obs, model)``.
    """
    obs = np.asarray(obs, dtype=float)
    model = np.asarray(model, dtype=float)
    valid = ~(np.isnan(obs) | np.isnan(model))
    obs, model = obs[valid], model[valid]
    if obs.size < 2:
        return f'insufficient paired samples (n={obs.size})'

    r_squared = sm.OLS(model, sm.add_constant(obs, has_constant='add')).fit().rsquared

    error = obs - model
    sq_err = np.square(error)
    abs_err = np.abs(error)
    mse_mean, mse_std = np.mean(sq_err), np.std(sq_err)
    rmse = np.sqrt(mse_mean)

    ioa = index_agreement(obs, model)

    return (f'R-squared: {r_squared}, MSE: {mse_mean} ({mse_std}), '
            f'RMSE: {rmse}, MAE: {np.mean(abs_err)} ({np.std(abs_err)}), IOA: {ioa}')


#OBS
idx = find_closest(np.array(x_OBS_5[:-1]),13)
ax3.axhline(abl_SCu_mean[idx],color='gray',linewidth=axiswidth*1.5,linestyle='--')

abl_value = abl_SCu_mean[idx]
idx_1 = find_closest(np.array(x_OBS_1),12.90)
idx_2 = find_closest(np.array(x_OBS_1),13.10)

ct_value = np.nanmean(ct_values[idx_1:idx_2+1])+np.nanstd(ct_values[idx_1:idx_2+1])

ax3.axhline(ct_value,color='gray',linewidth=axiswidth*1.5,linestyle='-.')

# Flight-2 samples. NOTE(review): the -120 m offset presumably converts the
# aircraft altitude to height above the site — confirm.
flight2_z_agl = np.array(flight2_all_z)-120
ax3.plot(flight2_all_co2,flight2_z_agl,linewidth=axiswidth*2,color='tab:red',ls="",marker='o',markersize=2)
ax3.plot([],[],linewidth=axiswidth*2,color='tab:red',ls="",marker='o',markersize=8,label='OBS')  # enlarged legend proxy

# Horizontal cross-section legs (mean ± std) at fixed plotting heights
for leg_mean, leg_z, leg_std in ((flight2_low_mean_co2, 200, flight2_low_std_co2),
                                 (flight2_middle_mean_co2, 1100, flight2_middle_std_co2),
                                 (flight2_high_mean_co2, 3000, flight2_high_std_co2)):
    (_, caps, _) = ax3.errorbar(leg_mean, leg_z, xerr=leg_std,
                                linewidth=3, color='tab:blue', ls="", marker='^',
                                markersize=10, markeredgecolor='k', markeredgewidth=0.5,
                                capsize=10)
    for cap in caps:
        cap.set_markeredgewidth(3)

ax3.plot([],[],linewidth=axiswidth*2,color='tab:blue',ls="-",marker='^',markersize=10,label='OBS horiz. cross section')

#Plot tower data at the observation time closest to 13 LT
idx = find_closest(np.array(x_OBS_1),13)
for tower_key, tower_z in (('24m', 24), ('38m', 38), ('53m', 53), ('79m', 79), ('321m', 321)):
    ax3.plot(SCu_CO2_mean[tower_key][idx],tower_z,linewidth=axiswidth*2,color='tab:red',ls="",marker='o',markersize=2)

#DALES
idx = find_closest(np.array(x_DALES),13)

ax3.axhline(abl_1D[idx],color='tab:green',linewidth=axiswidth*1.5,linestyle='--')

ax3.plot(profiles_CO2[idx],profiles_zt,linewidth=axiswidth*5,color='w',ls="-")
ax3.plot(profiles_CO2[idx],profiles_zt,linewidth=axiswidth*3,color='tab:green',ls="-",label='DALES')

ax3.fill_between(np.arange(414,424,2),abl_1D[idx],tmser_zc_max[idx],color='k',alpha=0.05,label='DALES cloud layer')

#Detailing
ax3.set_xlabel('CO$_2$ (ppm)')
ax3.set_xlim(414,422)
ax3.set_ylim(0,3500)

ax3.set_xticks(np.arange(414,424,2),np.arange(414,424,2))
ax3.set_yticks(np.arange(0,4000,500),np.arange(0,4000,500))

ax3.tick_params(axis='x')
ax3.tick_params(axis='y')
ax3.spines['right'].set_visible(False)
ax3.spines['top'].set_visible(False)

ax3.yaxis.set_minor_locator(AutoMinorLocator())
ax3.xaxis.set_minor_locator(AutoMinorLocator())
ax3.set_title('13 LT',y=1.02)

# Line labels placed just right of the axes (x = 423 > xlim)
ax3.text(423,abl_value-20,'OBS z$_i$',color='k',fontsize=plt.rcParams['font.size'])
ax3.text(423,abl_1D[idx]-200,'DALES z$_i$',color='tab:green',fontsize=plt.rcParams['font.size'])
ax3.text(423,ct_value-100,'OBS z$_{ctop}$',color='k',fontsize=plt.rcParams['font.size'])

statistics = {}

# DALES CO2 at the model level closest to each aircraft sample height.
co2_dales = profiles_CO2[idx]
co2_dales_at_obs = np.array([co2_dales[find_closest(profiles_zt,height)] for height in flight2_z_agl])

# Select samples by height: the flight series is ordered in time, so cutting
# at the index closest to 3500 m would only isolate the layer below 3500 m
# for a single monotonic ascent.
below_3500 = flight2_z_agl <= 3500
statistics['CO2'] = _agreement_stats(np.asarray(flight2_all_co2, dtype=float)[below_3500], co2_dales_at_obs[below_3500])

#export

json_string = json.dumps(statistics)

# Open a file for writing
with open(f'{outdir}/Figure_3c_statistics.txt', 'w') as f:
    # Write the JSON string to the file
    f.write(json_string)


#=================
# Figure 3 d
#=================
# Tower CO2 time series at five heights vs DALES CO2 near 321 m, with
# agreement statistics at 321 m (whole day and from 10 LT onward).

statistics = {}


def _agreement_stats(obs, model):
    """Return the OBS-vs-DALES agreement summary for paired samples as a string.

    Pairs in which either value is NaN are dropped first.

    - R-squared: from an OLS fit of ``model`` on ``obs`` with an intercept.
      Without the intercept statsmodels reports the *uncentred* R-squared,
      which is typically close to 1 for any data with a large mean.
    - MSE: mean (std) of the squared errors.
    - RMSE: square root of the mean squared error. (The mean of the
      per-sample square roots is the mean absolute error, reported as MAE.)
    - MAE: mean (std) of the absolute errors.
    - IOA: ``index_agreement(obs, model)``.
    """
    obs = np.asarray(obs, dtype=float)
    model = np.asarray(model, dtype=float)
    valid = ~(np.isnan(obs) | np.isnan(model))
    obs, model = obs[valid], model[valid]
    if obs.size < 2:
        return f'insufficient paired samples (n={obs.size})'

    r_squared = sm.OLS(model, sm.add_constant(obs, has_constant='add')).fit().rsquared

    error = obs - model
    sq_err = np.square(error)
    abs_err = np.abs(error)
    mse_mean, mse_std = np.mean(sq_err), np.std(sq_err)
    rmse = np.sqrt(mse_mean)

    ioa = index_agreement(obs, model)

    return (f'R-squared: {r_squared}, MSE: {mse_mean} ({mse_std}), '
            f'RMSE: {rmse}, MAE: {np.mean(abs_err)} ({np.std(abs_err)}), IOA: {ioa}')


#OBS: one colour per tower height
colors = ['tab:gray','tab:blue','tab:orange','#D8B40C','tab:red']
for height, color in zip([24, 38, 53, 79, 321], colors):
    height_str = f'{height}m'

    ax4.plot(x_OBS_1,SCu_CO2_mean[height_str],linewidth=axiswidth*2,color=color,ls="",marker='o',markersize=8,markeredgecolor='k',markeredgewidth=0.5,label=f'OBS {height} m')
    ax4.fill_between(x_OBS_1,SCu_CO2_mean[height_str]-SCu_CO2_std[height_str],SCu_CO2_mean[height_str]+SCu_CO2_std[height_str],color=color,alpha=0.2)


#LES: DALES level closest to the 321 m tower level
idx = find_closest(np.array(profiles_zt),321)
ax4.plot(x_DALES,profiles_CO2[:,idx],linewidth=axiswidth*5,color='w',ls="-")
ax4.plot(x_DALES,profiles_CO2[:,idx],linewidth=axiswidth*2,color='tab:green',ls="-",label='DALES 321 m')

#Statistics
# DALES CO2 at ~321 m sampled at the DALES time closest to each tower time.
co2_dales_321 = profiles_CO2[:,idx]
x_DALES_arr = np.array(x_DALES)
co2_dales_at_obs = np.array([co2_dales_321[find_closest(x_DALES_arr,time)] for time in x_OBS_1])
co2_obs_321 = np.asarray(SCu_CO2_mean['321m'], dtype=float)

statistics['321m'] = _agreement_stats(co2_obs_321, co2_dales_at_obs)

#After 10 LT
idx_10 = find_closest(np.array(x_OBS_1),10)
statistics['321m-10LT'] = _agreement_stats(co2_obs_321[idx_10:], co2_dales_at_obs[idx_10:])


#detailing
ax4.set_xlabel('LT (hour)')
ax4.set_ylabel('CO$_2$ (ppm)')

ax4.tick_params(axis='x')
ax4.tick_params(axis='y')

ax4.yaxis.set_minor_locator(AutoMinorLocator())
ax4.xaxis.set_minor_locator(AutoMinorLocator())
ax4.spines['right'].set_visible(False)
ax4.spines['top'].set_visible(False)

ax4.set_xlim(5.8,18.2)
ax4.set_xticks([6,8,10,12,14,16,18],[6,8,10,12,14,16,18])
ax4.set_ylim(410,480)


#Export statistics
json_string = json.dumps(statistics)

# Open a file for writing
with open(f'{outdir}/Figure_3d_statistics.txt', 'w') as f:
    # Write the JSON string to the file
    f.write(json_string)

#=====================================================================

#Total detailing: figure legends, panel letters and grids, then export Figure 3.
ax3.legend(loc='lower left', bbox_to_anchor=(-3.2, 1.3), ncol=4,
           borderaxespad=0, frameon=False, markerscale=1)
ax4.legend(loc='lower left', bbox_to_anchor=(.5, 0.7), ncol=2,
           borderaxespad=0, frameon=False, markerscale=1)

panel_letter_size = plt.rcParams['font.size']*1.2
# (axes, x of the panel letter in axes coordinates, letter)
for panel, letter_x, letter in ((ax1, -0.1, '(a)'), (ax2, -0.1, '(b)'),
                                (ax3, -0.1, '(c)'), (ax4, -0.02, '(d)')):
    panel.text(letter_x, 1.1, letter, transform=panel.transAxes, fontsize=panel_letter_size)
    panel.grid(linestyle=':', color='tab:gray', alpha=0.3, linewidth=axiswidth*2)

for extension, save_kwargs in (('jpg', {'dpi': 500}), ('svg', {}), ('pdf', {})):
    plt.savefig(f'{outdir}/Figure_3.{extension}', bbox_inches='tight', **save_kwargs)
plt.show()
plt.close()

#%%
#===========================================
# Supplements - Respiration and Net Radiation
#===========================================
# Two time-series panels: (a) respiration, (b) net radiation; OBS vs DALES.

fig, ax = plt.subplots(1,2,dpi=72,figsize=(8.27*1.2, 11.69/3))

ax1, ax2 = ax.flatten()


def _finish_timeseries_panel(panel, y_label, y_limits, letter):
    """Apply the shared S1 panel styling: labels, minor ticks, spines, limits, panel letter."""
    panel.set_xlabel('LT (hour)')
    panel.set_ylabel(y_label)

    panel.yaxis.set_minor_locator(AutoMinorLocator())
    panel.xaxis.set_minor_locator(AutoMinorLocator())
    for side in ('right', 'top'):
        panel.spines[side].set_visible(False)

    panel.set_xlim(6, 18)
    hour_ticks = [6, 8, 10, 12, 14, 16, 18]
    panel.set_xticks(hour_ticks, hour_ticks)

    panel.set_ylim(*y_limits)

    panel.text(-0.1, 1.1, letter, transform=panel.transAxes, fontsize=plt.rcParams['font.size']*1.2)


#=================
#    Figure a
#=================

#OBS
ax1.plot(x_OBS_30, SCu_Resp_mean, linewidth=axiswidth*2, color='tab:red', marker='o', ls="",
         markeredgecolor='k', markersize=7, markeredgewidth=1, label='OBS')
ax1.fill_between(x_OBS_30, SCu_Resp_mean-SCu_Resp_std, SCu_Resp_mean+SCu_Resp_std, color='tab:red', alpha=0.2)

#DALES
ax1.plot(x_DALES, DALES_Resp_AGS, linewidth=axiswidth*3, color='tab:green', ls="-", label='DALES')

_finish_timeseries_panel(ax1, r'Resp ($\mu$mol m$^{-2}$ s$^{-1}$)', (4.5, 6.5), '(a)')

ax1.legend(loc='lower left', bbox_to_anchor=(0.65, 1.2), ncol=2,
           borderaxespad=0, frameon=False, markerscale=1)

#=================
#    Figure b
#=================

#OBS
qnet_mean = SCu_rad_mean['Qnet']
qnet_std = SCu_rad_std['Qnet']
ax2.plot(x_OBS_10, qnet_mean, linewidth=axiswidth*2, color='tab:red', marker='o', ls="",
         markeredgecolor='k', markersize=7, markeredgewidth=1, label='OBS')
ax2.fill_between(x_OBS_10, qnet_mean-qnet_std, qnet_mean+qnet_std, color='tab:red', alpha=0.2)

#DALES: Qnet at the model half-level closest to 75 m
idx = find_closest(np.array(profiles_zm), 75)
DALES_rad = profiles_Qnet[:, idx]
ax2.plot(x_DALES, DALES_rad, linewidth=axiswidth*3, color='tab:green', ls="-", label='DALES')

_finish_timeseries_panel(ax2, r'Q$_{net}$ (W m$^{-2}$)', (-100, 1000), '(b)')

plt.subplots_adjust(wspace=0.5)

for panel in (ax1, ax2):
    panel.grid(linestyle=':', color='tab:gray', alpha=0.3, linewidth=axiswidth*2)

#Save
for extension, save_kwargs in (('jpg', {'dpi': 500}), ('svg', {}), ('pdf', {})):
    plt.savefig(f'{outdir}/Figure_S1.{extension}', bbox_inches='tight', **save_kwargs)

plt.show()
plt.close()

#%%
#===========================================
# Supplements - Vertical Profiles (theta, q and CO2)
#===========================================
# 3x5 grid: theta (top row) and q (middle row) at 6, 9, 12, 15, 18 LT; the
# bottom row keeps three CO2 panels (ax11, ax12, ax14) and drops ax13/ax15.

fig, ax = plt.subplots(3,5,dpi=72,figsize=(8.27, 11.69/2),sharey=True)

ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8, ax9, ax10, ax11, ax12, ax13, ax14, ax15 = ax.flatten()

theta_axes = [ax1, ax2, ax3, ax4, ax5]
q_axes = [ax6, ax7, ax8, ax9, ax10]
co2_axes = [ax11, ax12, ax14]

grid_style = dict(linestyle=':',color='tab:gray',alpha=0.3,linewidth=axiswidth*2)

#Set position ([left, bottom, width, height]) of the theta and q rows
for left, theta_ax, q_ax in zip([0.05, 0.25, 0.45, 0.65, 0.85], theta_axes, q_axes):
    theta_ax.set_position([left, 0.55, 0.15, 0.4])
    q_ax.set_position([left, -0.12, 0.15, 0.4])

for panel in theta_axes + q_axes:
    panel.grid(**grid_style)


#Turn axes 13 and 15 off
fig.delaxes(ax13)
fig.delaxes(ax15)

#CO2 row: three centred panels
for left, panel in zip([0.25, 0.45, 0.65], co2_axes):
    panel.set_position([left, -0.82, 0.15, 0.4])
    panel.grid(**grid_style)

times = [6,9,12,15,18]
z = SCu_values_z_ref[0][0]


def _draw_profile_row(row_axes, obs_ref, dales_profiles, xlabel, xlim, xticks, obs_label=None):
    """Draw one row of OBS (mean ± std) and DALES profiles, one panel per entry of ``times``.

    obs_ref: indexed ``[:, time_slot, :]`` with time_slot in the order of ``times``;
        axis 0 is averaged (mean/std) and the result is plotted against ``z``.
    dales_profiles: indexed by DALES time index, giving a profile on ``profiles_zt``.
    obs_label: legend label for the OBS markers, or None for no label.
    """
    obs_ref = np.array(obs_ref)
    obs_kwargs = {} if obs_label is None else {'label': obs_label}

    for sel, (panel, hour) in enumerate(zip(row_axes, times)):
        values = obs_ref[:, sel, :]
        obs_mean = np.nanmean(values, axis=0)
        obs_std = np.nanstd(values, axis=0)

        #OBS
        panel.plot(obs_mean,z,linewidth=axiswidth*2,color='tab:red',ls="",marker='o',markersize=3,**obs_kwargs)
        panel.fill_betweenx(z,obs_mean-obs_std,obs_mean+obs_std,color='tab:red',alpha=0.2)

        #DALES at the model time closest to this hour
        idx = find_closest(np.array(x_DALES),hour)
        panel.plot(dales_profiles[idx],profiles_zt,linewidth=axiswidth*5,color='w',ls="-")
        panel.plot(dales_profiles[idx],profiles_zt,linewidth=axiswidth*3,color='tab:green',ls="-",label='DALES')

        #detailing
        panel.set_xlabel(xlabel)

        panel.set_ylim(0,3500)
        panel.set_yticks(np.arange(0,4000,500),np.arange(0,4000,500))

        panel.set_xlim(*xlim)
        panel.set_xticks(xticks,xticks)

        panel.spines['right'].set_visible(False)
        panel.spines['top'].set_visible(False)

        panel.yaxis.set_minor_locator(AutoMinorLocator())
        panel.xaxis.set_minor_locator(AutoMinorLocator())
        panel.set_title(f'{hour} LT',y=1.02)


#Theta - top row (row label and panel letter are created once, not per panel)
_draw_profile_row(theta_axes, SCu_values_theta_ref, profiles_theta, r'$\theta$ (K)',
                  (295,315), np.arange(295,320,10), obs_label='OBS')
ax1.set_ylabel('z (m)')
ax1.text(-0.9,1.1, '(a)',transform=ax1.transAxes,fontsize=plt.rcParams['font.size']*1.2)

#q - middle row
_draw_profile_row(q_axes, SCu_values_q_ref, profiles_q, '$q$ (g kg$^{-1}$)',
                  (0,20), np.arange(0,25,5))
ax6.set_ylabel('z (m)')
ax6.text(-0.9,1.1, '(b)',transform=ax6.transAxes,fontsize=plt.rcParams['font.size']*1.2)

#CO2 - bottom row

#6 LT
# No observations are drawn at 6 LT. The NaN artists below only register the
# 'OBS' and 'OBS Horiz. cross section' entries for this row's legend.

#OBS (legend proxies)
ax11.plot([np.nan], [np.nan], linewidth=axiswidth*2, color='tab:red', ls="",
          marker='o', markersize=7, label='OBS')
ax11.errorbar([np.nan], 200, xerr=[np.nan],
              linewidth=3, color='tab:blue', ls="", marker='^',
              markersize=10, markeredgecolor='k', markeredgewidth=0.5,
              capsize=10, label='OBS Horiz. cross section')

#DALES profile at the model time closest to 6 LT (white halo, then green line)
idx = find_closest(np.array(x_DALES), 6)
for line_width, line_color, extra in ((axiswidth*5, 'w', {}),
                                      (axiswidth*3, 'tab:green', {'label': 'DALES'})):
    ax11.plot(profiles_CO2[idx], profiles_zt, linewidth=line_width, color=line_color, ls="-", **extra)

#detailing
ax11.set_xlabel('CO$_2$ (ppm)')
ax11.set_ylabel('z (m)')
ax11.set_xlim(410, 450)
co2_ticks = np.arange(410, 460, 10)
ax11.set_xticks(co2_ticks, co2_ticks, rotation=45)

ax11.set_ylim(0, 3500)
height_ticks = np.arange(0, 4000, 500)
ax11.set_yticks(height_ticks, height_ticks)

for side in ('right', 'top'):
    ax11.spines[side].set_visible(False)

ax11.yaxis.set_minor_locator(AutoMinorLocator())
ax11.xaxis.set_minor_locator(AutoMinorLocator())

ax11.set_title('6 LT', y=1.02)

ax11.text(-0.9, 1.1, '(c)', transform=ax11.transAxes, fontsize=plt.rcParams['font.size']*1.2)

ax11.legend(loc='lower left', bbox_to_anchor=(-0.75, -0.8), ncol=3,
            borderaxespad=0, frameon=False, markerscale=1)

#flight 1 - 9 LT
# CO2 panel ax12: flight-1 samples, three horizontal cross-section legs,
# tower levels at 9 LT and the DALES profile at 9 LT.

#OBS
ax12.plot(flight1_all_co2, np.array(flight1_all_z)-120, linewidth=axiswidth*2, color='tab:red',
          ls="", marker='o', markersize=2, label='OBS')

# (leg mean, plotting height, leg std, extra kwargs); only the first leg is labelled
flight1_legs = (
    (flight1_low_mean_co2, 200, flight1_low_std_co2, {'label': 'OBS Horiz. cross section'}),
    (flight1_middle_mean_co2, 1100, flight1_middle_std_co2, {}),
    (flight1_high_mean_co2, 3000, flight1_high_std_co2, {}),
)
for leg_mean, leg_z, leg_std, leg_extra in flight1_legs:
    leg_caps = ax12.errorbar(leg_mean, leg_z, xerr=leg_std,
                             linewidth=3, color='tab:blue', ls="", marker='^',
                             markersize=10, markeredgecolor='k', markeredgewidth=0.5,
                             capsize=10, **leg_extra)[1]
    for cap in leg_caps:
        cap.set_markeredgewidth(3)

#Plot tower data at the observation time closest to 9 LT
tower_idx = find_closest(np.array(x_OBS_1), 9)
for tower_key, tower_z in (('24m', 24), ('38m', 38), ('53m', 53), ('79m', 79), ('321m', 321)):
    ax12.plot(SCu_CO2_mean[tower_key][tower_idx], tower_z, linewidth=axiswidth*2, color='tab:red',
              ls="", marker='o', markersize=2)

#DALES profile at the model time closest to 9 LT (white halo, then green line)
idx = find_closest(np.array(x_DALES), 9)
ax12.plot(profiles_CO2[idx], profiles_zt, linewidth=axiswidth*5, color='w', ls="-")
ax12.plot(profiles_CO2[idx], profiles_zt, linewidth=axiswidth*3, color='tab:green', ls="-", label='DALES')

#Detailing
ax12.set_xlabel('CO$_2$ (ppm)')
ax12.set_xlim(410, 440)
co2_ticks = np.arange(410, 450, 10)
ax12.set_xticks(co2_ticks, co2_ticks, rotation=45)

ax12.set_ylim(0, 3500)
height_ticks = np.arange(0, 4000, 500)
ax12.set_yticks(height_ticks, height_ticks)

for side in ('right', 'top'):
    ax12.spines[side].set_visible(False)

ax12.yaxis.set_minor_locator(AutoMinorLocator())
ax12.xaxis.set_minor_locator(AutoMinorLocator())
ax12.set_title('9 LT', y=1.02)

#flight 2 - 12 LT

#OBS
ax14.plot(flight2_all_co2,np.array(flight2_all_z)-120,linewidth=axiswidth*2,color='tab:red',ls="",marker='o',markersize=2)

(_, caps, _) = ax14.errorbar(flight2_low_mean_co2, 200, xerr=flight2_low_std_co2, 
             linewidth=3, color='tab:blue', ls="", marker='^', 
             markersize=10, markeredgecolor='k', markeredgewidth=0.5, 
             capsize=10,label='OBS Horiz. cross section')

for cap in caps:
    cap.set_markeredgewidth(3)

(_, caps, _) = ax14.errorbar(flight2_middle_mean_co2, 1100, xerr=flight2_middle_std_co2, 
             linewidth=3, color='tab:blue', ls="", marker='^', 
             markersize=10, markeredgecolor='k', markeredgewidth=0.5, 
             capsize=10)

for cap in caps:
    cap.set_markeredgewidth(3)

(_, caps, _) = ax14.errorbar(flight2_high_mean_co2, 3000, xerr=flight2_high_std_co2, 
             linewidth=3, color='tab:blue', ls="", marker='^', 
             markersize=10, markeredgecolor='k', markeredgewidth=0.5, 
             capsize=10)

for cap in caps:
    cap.set_markeredgewidth(3)

#Plot tower data
idx = find_closest(np.array(x_OBS_1),13)

ax14.plot(SCu_CO2_mean['24m'][idx],24,linewidth=axiswidth*2,color='tab:red',ls="",marker='o',markersize=2)
ax14.plot(SCu_CO2_mean['38m'][idx],38,linewidth=axiswidth*2,color='tab:red',ls="",marker='o',markersize=2)
ax14.plot(SCu_CO2_mean['53m'][idx],53,linewidth=axiswidth*2,color='tab:red',ls="",marker='o',markersize=2)
ax14.plot(SCu_CO2_mean['79m'][idx],79,linewidth=axiswidth*2,color='tab:red',ls="",marker='o',markersize=2)
ax14.plot(SCu_CO2_mean['321m'][idx],321,linewidth=axiswidth*2,color='tab:red',ls="",marker='o',markersize=2)

#DALES
idx = find_closest(np.array(x_DALES),13)

ax14.plot(profiles_CO2[idx],profiles_zt,linewidth=axiswidth*5,color='w',ls="-")
ax14.plot(profiles_CO2[idx],profiles_zt,linewidth=axiswidth*3,color='tab:green',ls="-",label='DALES')

#Detailing
ax14.set_xlabel('CO$_2$ (ppm)')
ax14.set_xlim(414,422)
ax14.set_xticks(np.arange(414,424,2),np.arange(414,424,2),rotation=45)

ax14.set_ylim(0,3500)
ax14.set_yticks(np.arange(0,4000,500),np.arange(0,4000,500))

ax14.tick_params(axis='x')
ax14.tick_params(axis='y')
ax14.spines['right'].set_visible(False)
ax14.spines['top'].set_visible(False)

ax14.yaxis.set_minor_locator(AutoMinorLocator())
ax14.xaxis.set_minor_locator(AutoMinorLocator())
ax14.set_title('13 LT',y=1.02)


#Save
plt.savefig(f'{outdir}/Figure_S2.jpg',bbox_inches='tight',dpi=500)
plt.savefig(f'{outdir}/Figure_S2.svg',bbox_inches='tight')
plt.savefig(f'{outdir}/Figure_S2.pdf',bbox_inches='tight')

plt.show()
plt.close()

#%%
#===========================================
# Supplements - Vertical Profiles (u and v)
#===========================================

fig, ax = plt.subplots(2,5,dpi=72,figsize=(8.27*1.5, 11.69/1.2),sharey=True)

ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8, ax9, ax10 = ax.flatten()

times = [6,9,12,15,18]


def _plot_wind_profile_row(axes_row, hours, obs_profiles, dales_profiles, z, xlabel, xlim, xticks, panel_label):
    """Fill one row of the wind-profile figure, one panel per hour.

    Parameters
    ----------
    axes_row : the Axes of the row, in the same order as ``hours``.
    hours : local times (LT) shown left to right.
    obs_profiles : array of observed profiles indexed as [profile, time, height]. The
        time axis is assumed to follow ``hours`` (column i uses time index i), as the
        original loop did.
    dales_profiles : DALES profiles indexed by find_closest(x_DALES, hour).
    z : observation heights (m), shared by all observed profiles.
    xlabel, xlim, xticks : x-axis label, limits and major ticks of every panel.
    panel_label : row letter, e.g. '(a)'; drawn once, on the row's first panel.
    """
    for col, (axis, hour) in enumerate(zip(axes_row, hours)):
        # OBS: mean +/- one standard deviation across the observed profiles at this hour
        values = obs_profiles[:, col, :]
        obs_mean = np.nanmean(values, axis=0)
        obs_std = np.nanstd(values, axis=0)
        axis.plot(obs_mean, z, linewidth=axiswidth*2, color='tab:red', ls="", marker='o', markersize=3, label='OBS')
        axis.fill_betweenx(z, obs_mean-obs_std, obs_mean+obs_std, color='tab:red', alpha=0.2)

        # DALES: wide white line underneath gives the green profile a halo
        idx = find_closest(np.array(x_DALES), hour)
        axis.plot(dales_profiles[idx], profiles_zt, linewidth=axiswidth*5, color='w', ls="-")
        axis.plot(dales_profiles[idx], profiles_zt, linewidth=axiswidth*3, color='tab:green', ls="-", label='DALES')

        # Detailing
        axis.set_xlabel(xlabel)

        axis.set_ylim(0, 3500)
        axis.set_yticks(np.arange(0, 4000, 500), np.arange(0, 4000, 500))

        axis.set_xlim(*xlim)
        axis.set_xticks(xticks, xticks, rotation=45)

        axis.spines['right'].set_visible(False)
        axis.spines['top'].set_visible(False)

        axis.yaxis.set_minor_locator(AutoMinorLocator())
        axis.xaxis.set_minor_locator(AutoMinorLocator())
        axis.set_title(f'{hour} LT', y=1.02)
        axis.axvline(x=0, color='tab:gray', linestyle=':', linewidth=axiswidth*2)

    # Row annotations are added once, on the row's own first axes. Adding them
    # inside the loop would stack one copy per column.
    first = axes_row[0]
    first.set_ylabel('z (m)')
    first.text(-0.9, 1.1, panel_label, transform=first.transAxes, fontsize=plt.rcParams['font.size']*1.2)


# Observed heights are identical for every profile; take them from the first one.
z = SCu_values_z_ref[0][0]

# u - top row (a)
_plot_wind_profile_row(ax[0], times, np.array(SCu_values_wind_u_ref), profiles_u, z,
                       r'$u$ (m s$^{-1}$)', (-12, 2), np.arange(-12, 4, 2), '(a)')

# v - bottom row (b)
_plot_wind_profile_row(ax[1], times, np.array(SCu_values_wind_v_ref), profiles_v, z,
                       r'$v$ (m s$^{-1}$)', (-6, 6), np.arange(-6, 8, 2), '(b)')

plt.subplots_adjust(wspace=0.3,hspace=0.7)

#Save
plt.savefig(f'{outdir}/Figure_S3.jpg',bbox_inches='tight',dpi=500)
plt.savefig(f'{outdir}/Figure_S3.svg',bbox_inches='tight')
plt.savefig(f'{outdir}/Figure_S3.pdf',bbox_inches='tight')

plt.show()
plt.close()