# -*- coding: utf-8 -*-
"""
Created on Tue May 16 11:53:22 2023

@author: david
"""

import numpy as np
import matplotlib.pyplot as plt

def make_example_data(n=1000):
    """Sample ``sin(100 x) + sin(200 (x + y))`` on a uniform grid over [0, 1]^2.

    Parameters
    ----------
    n : int, optional
        Number of samples along each axis (endpoints included). Default 1000.

    Returns
    -------
    x, y : ndarray, shape (n,)
        Sample coordinates along the x and y axes.
    Z : ndarray, shape (n, n)
        Signal values. ``np.meshgrid`` uses ``'xy'`` indexing by default, so
        ``Z[i, j]`` is the value at ``(x[j], y[i])``: rows follow y and
        columns follow x.
    """
    x = np.linspace(0, 1, n)
    y = np.linspace(0, 1, n)
    X, Y = np.meshgrid(x, y)
    Z = np.sin(100 * X) + np.sin(200 * (X + Y))
    return x, y, Z


def centered_spectrum(Z, dx, dy):
    """Compute the zero-centred 2D FFT of ``Z`` and its wavenumber axes.

    Parameters
    ----------
    Z : ndarray, shape (ny, nx)
        Samples with rows along y and columns along x.
    dx, dy : float
        Sample spacing along x (columns) and y (rows).

    Returns
    -------
    Z_fft_shifted : complex ndarray, shape (ny, nx)
        ``fft2(Z)`` with the zero-frequency term moved to the centre.
    kx, ky : ndarray, shapes (nx,) and (ny,)
        Ascending angular wavenumbers (radians per unit length) for the
        columns and rows of ``Z_fft_shifted``.
    """
    Z_fft_shifted = np.fft.fftshift(np.fft.fft2(Z))
    # fftfreq returns ordinary frequency (cycles per unit length); scale by
    # 2*pi so that a term sin(k * x) in the data peaks at kx = +/-k.
    kx = 2 * np.pi * np.fft.fftshift(np.fft.fftfreq(Z.shape[1], dx))
    ky = 2 * np.pi * np.fft.fftshift(np.fft.fftfreq(Z.shape[0], dy))
    return Z_fft_shifted, kx, ky


def pixel_extent(u, v):
    """Return an ``imshow`` extent that centres pixels on the samples ``u``, ``v``.

    ``extent`` describes the outer edges of the image, so passing the first
    and last sample values directly would offset every pixel by half a step.
    Assumes ``u`` and ``v`` are ascending, uniformly spaced, and have at
    least two samples each.

    Returns
    -------
    list of float
        ``[left, right, bottom, top]``.
    """
    du = u[1] - u[0]
    dv = v[1] - v[0]
    return [u[0] - du / 2, u[-1] + du / 2, v[0] - dv / 2, v[-1] + dv / 2]


def plot_data_and_spectrum(x, y, Z, kx, ky, Z_fft_shifted):
    """Draw ``Z`` and ``|Z_fft_shifted|`` side by side and return ``(fig, axs)``.

    Both images use ``origin='lower'`` because row 0 holds the smallest y
    (and, after ``fftshift``, the most negative ky). With ``imshow``'s
    default ``origin='upper'`` that row would be drawn at the top, so the
    pictures would appear upside-down relative to their y/ky tick labels.
    """
    fig, axs = plt.subplots(1, 2, figsize=(10, 4), dpi=300)
    axs[0].imshow(Z, cmap='hot', origin='lower', extent=pixel_extent(x, y))
    axs[0].set_title('Original data')
    axs[0].set_xlabel('x')
    axs[0].set_ylabel('y')
    axs[1].imshow(np.abs(Z_fft_shifted), cmap='hot', origin='lower',
                  extent=pixel_extent(kx, ky))
    axs[1].set_title('2D FFT')
    axs[1].set_xlabel('kx')
    axs[1].set_ylabel('ky')
    return fig, axs


# Run as a script (or from an IDE): the names below stay in the module
# namespace for interactive inspection, but merely importing this file no
# longer computes or plots anything.
if __name__ == "__main__":
    # Generate some example data
    x, y, Z = make_example_data()
    # Compute the centred 2D FFT and its angular-wavenumber axes
    Z_fft_shifted, kx, ky = centered_spectrum(Z, x[1] - x[0], y[1] - y[0])
    # Plot the original data and its FFT
    fig, axs = plot_data_and_spectrum(x, y, Z, kx, ky, Z_fft_shifted)
    plt.show()