# FFT/2d_fourie/main.py
# 2023-02-16 10:07:19 +01:00
#
# 404 lines
# 11 KiB
# Python

from lattices import SCC_Lattice
import numpy as np
import matplotlib.pyplot as plt
import scipy.fftpack as sfft
import matplotlib.patches as patches
import matplotlib
import scipy
import scipy.signal
import tqdm
class SpinImage:
    """Raster image of spin positions with padding, windowing and FFT helpers.

    ``img`` is indexed ``img[ix, iy]``: axis 0 follows x, axis 1 follows y.
    The pixel size is the class attribute ``resolution`` (the grid comments
    below call the unit angstrom).
    """

    resolution = 0.1

    def __init__(self, x_pos, y_pos):
        """Rasterise the points ``(x_pos[i], y_pos[i])`` into ``self.img``.

        Negative coordinates (for example after rotating a lattice about the
        origin) are shifted so that the most negative one lands on pixel 0.
        Non-negative input is left untouched, so its origin stays at 0.
        """
        x_pos = np.asarray(x_pos, dtype=float)
        y_pos = np.asarray(y_pos, dtype=float)
        # searchsorted() sends every negative value to index 0. That would pile
        # all of them onto the first row/column and leave them outside
        # length_x/length_y, so shift them into range instead.
        x_pos = x_pos - min(0.0, np.min(x_pos))
        y_pos = y_pos - min(0.0, np.min(y_pos))
        self.length_x = np.max(x_pos) + self.resolution
        self.length_y = np.max(y_pos) + self.resolution
        self.img = self.image_from_pos(x_pos, y_pos)

    def image_from_pos(self, pos_x, pos_y):
        """Return a 0/1 image with a pixel set for every (pos_x, pos_y) pair.

        The grid spans ``[0, length_x)`` x ``[0, length_y)`` in steps of
        ``resolution``. Each coordinate goes to the first grid point that is
        >= it (``np.searchsorted``, side="left"). Negative values therefore
        map to index 0, and values beyond the last grid point raise IndexError.
        """
        x_ind = np.arange(0, self.length_x, self.resolution)  # angstrom
        y_ind = np.arange(0, self.length_y, self.resolution)  # angstrom
        img = np.zeros((x_ind.size, y_ind.size))
        xind = np.searchsorted(x_ind, pos_x)
        yind = np.searchsorted(y_ind, pos_y)
        img[xind, yind] = 1
        return img

    def fft(self):
        """Return ``(freq_x, freq_y, power)`` for the current image.

        ``power`` is ``|FFT2(img)|**2``, shifted so that zero frequency is
        centred. The frequency axes are shifted the same way. They are in
        cycles per unit of ``resolution``, because ``resolution`` is the
        sample spacing passed to ``fftfreq``.
        """
        Z_fft = sfft.fft2(self.img)
        Z_shift = sfft.fftshift(Z_fft)
        fft_freqx = sfft.fftfreq(self.img.shape[0], self.resolution)
        fft_freqy = sfft.fftfreq(self.img.shape[1], self.resolution)
        fft_freqx_clean = sfft.fftshift(fft_freqx)
        fft_freqy_clean = sfft.fftshift(fft_freqy)
        return fft_freqx_clean, fft_freqy_clean, np.abs(Z_shift) ** 2

    def pad_it_square(self, additional_pad=0):
        """Zero-pad ``img`` in place to a square of side
        ``max(h, w) + 2 * additional_pad``.

        The content is centred; when the padding is odd, the extra pixel goes
        after the content. ``length_x``/``length_y`` are updated to the padded
        extent.
        """
        h = self.img.shape[0]
        w = self.img.shape[1]
        print(h, w)
        xx = np.maximum(h, w) + 2 * additional_pad
        yy = xx
        self.length_x = xx * self.resolution
        self.length_y = yy * self.resolution
        print("Pad to: ", xx, yy)
        a = (xx - h) // 2
        aa = xx - a - h
        b = (yy - w) // 2
        bb = yy - b - w
        self.img = np.pad(self.img, pad_width=((a, aa), (b, bb)), mode="constant")

    def gaussian(self, sigma):
        """Multiply ``img`` in place by a normalised 2-D Gaussian window.

        The window is centred on the image. ``sigma`` is in the same unit as
        ``resolution``.
        """
        # Build the grid from the image shape, not from a float
        # np.arange(-L/2, L/2, step). The arange length is computed as
        # ceil(L/step), which can come out one sample too long (e.g. 3 -> 4)
        # and then fail to broadcast against img.
        n_x, n_y = self.img.shape
        x = (np.arange(n_x) - n_x / 2) * self.resolution
        y = (np.arange(n_y) - n_y / 2) * self.resolution
        X, Y = np.meshgrid(x, y, indexing="ij")  # shape (n_x, n_y), like img
        z = (
            1 / (2 * np.pi * sigma * sigma)
            * np.exp(-(X**2 / (2 * sigma**2) + Y**2 / (2 * sigma**2)))
        )
        self.img = np.multiply(self.img, z)

    def plot(self, ax, scale=None):
        """Draw ``img`` on ``ax``.

        If ``scale`` is given, each spin is first widened to a square of
        ``scale / resolution`` pixels by 2-D convolution, so that isolated
        pixels stay visible.
        """
        if scale is None:
            ax.imshow(self.img)
        else:
            quad = np.ones((int(scale / self.resolution),
                            int(scale / self.resolution)))
            img = scipy.signal.convolve2d(self.img, quad)
            ax.imshow(img)

    def blur(self, sigma):
        """Apply ``scipy.ndimage.gaussian_filter`` in place (sigma in pixels)."""
        # The module header only imports scipy and scipy.signal, so
        # scipy.ndimage must be imported explicitly rather than relying on
        # another import having loaded it.
        import scipy.ndimage

        self.img = scipy.ndimage.gaussian_filter(self.img, sigma)
def plot(freqx, freqy, intens, ax_log=None, ax_lin=None):
    """Show a 2-D power spectrum on up to two Axes, each with a colour bar.

    Args:
        freqx, freqy: 1-D frequency axes, e.g. as returned by ``SpinImage.fft``.
        intens: 2-D intensities indexed ``intens[ix, iy]``; axis 0 follows
            ``freqx`` and axis 1 follows ``freqy``.
        ax_log: Axes to draw with a logarithmic colour scale, or None to skip.
        ax_lin: Axes to draw with a linear colour scale, or None to skip.
    """
    # imshow draws array rows vertically and, by default, puts row 0 at the
    # top. Transpose so that freqx runs horizontally, and use origin="lower"
    # so that the most negative freqy sits at the bottom, matching `extent`.
    image = np.asarray(intens).T
    extent = (np.min(freqx), np.max(freqx), np.min(freqy), np.max(freqy))

    panels = []
    if ax_log is not None:
        panels.append((ax_log, matplotlib.colors.LogNorm()))
    if ax_lin is not None:
        panels.append((ax_lin, None))

    for ax, norm in panels:
        mappable = ax.imshow(
            image,
            extent=extent,
            origin="lower",
            norm=norm,
            cmap="viridis",
        )
        # Attach the colour bar to this Axes explicitly instead of relying
        # on pyplot's implicit "current axes" state.
        ax.figure.colorbar(mappable, ax=ax)
def rotate(x, y, angle):
    """Rotate the points (x, y) counter-clockwise about the origin.

    Args:
        x, y: coordinates (scalars or arrays of the same shape).
        angle: rotation angle in degrees.

    Returns:
        Tuple ``(x_rot, y_rot)``.
    """
    # degrees -> radians is angle * pi / 180 (not 2 * pi / 180, which would
    # rotate by twice the requested angle).
    radian = np.deg2rad(angle)
    cos_a = np.cos(radian)
    sin_a = np.sin(radian)
    return cos_a * x - sin_a * y, sin_a * x + cos_a * y
def test_square():
    """Demo: rasterise an ``SCC_Lattice(40, 40)`` rotated with
    ``rotate(..., 30)``, window it, and plot its power spectrum.

    Top row: padded image, then the image after the Gaussian window.
    Bottom row: the spectrum on log and linear colour scales. The figure is
    saved to ``test.png`` and then shown.
    """
    lattice = SCC_Lattice(40, 40)
    xs, ys = lattice.get_from_mask(None)
    spins = SpinImage(*rotate(xs, ys, 30))
    _, axes = plt.subplots(2, 2)

    # Top row: raw (padded) image, then after the Gaussian window.
    spins.pad_it_square(10)
    spins.plot(axes[0, 0], 2)
    spins.gaussian(300)
    # spins.blur(3)
    spins.plot(axes[0, 1], 2)
    plt.pause(0.1)

    # Bottom row: power spectrum, log scale on the left, linear on the right.
    freq_x, freq_y, power = spins.fft()
    plot(freq_x, freq_y, power, axes[1, 0], axes[1, 1])

    print("Done")
    plt.savefig("test.png")
    plt.show()
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not on import.
    test_square()
# def test_lattice():
# lat = VO2_Lattice(10, 10)
# maske = np.zeros((10, 10), dtype=bool)
# x, y = lat.get_from_mask(maske)
#
# plt.scatter(x, y)
# maske = np.invert(maske)
# x, y = lat.get_from_mask(maske)
# plt.scatter(x, y)
#
# maske[:3, :5] = False
# x, y = lat.get_from_mask(maske)
# plt.scatter(x, y)
# plt.show()
#
#
# self.resolution = 0.1
# CMAP = "Greys"
#
#
# def test_img():
# lat = VO2_Lattice(10, 10)
# maske = np.ones((10, 10), dtype=bool)
# x, y = lat.get_from_mask(maske)
# img = image_from_pos(x, y)
# plt.imshow(img.T, origin="lower", extent=(0, np.max(x), 0, np.max(y)))
# plt.scatter(x, y)
# plt.show()
#
#
# def gaussian(img):
# x = np.arange(-self.resolution * img.shape[0]/2,
# self.resolution * img.shape[0]/2, self.resolution)
# y = np.arange(-self.resolution * img.shape[1]/2,
# self.resolution * img.shape[1]/2, self.resolution)
# X, Y = np.meshgrid(x, y)
# sigma = self.resolution * img.shape[0] / 10
# print("Sigma: ", sigma)
# z = (
# 1 / (2 * np.pi * sigma * sigma)
# * np.exp(-(X**2 / (2 * sigma**2) + Y**2 / (2 * sigma**2)))
# )
# return np.multiply(img, z.T)
#
#
# def rect_at_point(x, y, color):
# length_2 = 0.08
# rect = patches.Rectangle(
# (x - length_2, y - length_2),
# 2 * length_2,
# 2 * length_2,
# linewidth=1,
# edgecolor=color,
# facecolor="none",
# )
# return rect
#
#
# def reci_rutile():
# x = np.arange(-2, 3)
# y = np.arange(-2, 3)
# X, Y = np.meshgrid(x, y)
# return (X * 0.22 + Y * 0.44).flatten(), (X * 0.349).flatten()
#
#
# def reci_mono():
# x, y = reci_rutile()
# return x + 0.1083, y + 0.1719
#
#
# def draw_big_val_rect(img, x, y, x_index, y_index):
# length_2 = 0.08
# pos_x_lower = x - length_2
# pos_x_upper = x + length_2
#
# pos_y_lower = y - length_2
# pos_y_upper = y + length_2
# x_lower = np.searchsorted(x_index, pos_x_lower)
# x_upper = np.searchsorted(x_index, pos_x_upper)
# y_lower = np.searchsorted(y_index, pos_y_lower)
# y_upper = np.searchsorted(y_index, pos_y_upper)
#
# img[y_lower:y_upper, x_lower:x_upper] = 1e4
# return img
#
#
# def extract_rect(img, x, y, x_index, y_index):
# length_2 = 0.08
#
# pos_x_lower = x - length_2
# pos_x_upper = x + length_2
#
# pos_y_lower = y - length_2
# pos_y_upper = y + length_2
#
# x_lower = np.searchsorted(x_index, pos_x_lower)
# x_upper = np.searchsorted(x_index, pos_x_upper)
#
# y_lower = np.searchsorted(y_index, pos_y_lower)
# y_upper = np.searchsorted(y_index, pos_y_upper)
#
# # fix different number of spins possible
# if x_upper - x_lower < 10:
# x_upper += 1
# if y_upper - y_lower < 10:
# y_upper += 1
#
# return img[y_lower:y_upper, x_lower:x_upper]
#
#
# def extract_peaks(freqx, freqy, intens):
# rutile = []
# point_x, point_y = reci_rutile()
# for px, py in zip(point_x, point_y):
# rutile.append(reduce(extract_rect(intens, px, py, freqx, freqy)))
#
# mono = []
# point_x, point_y = reci_mono()
# for px, py in zip(point_x, point_y):
# mono.append(reduce(extract_rect(intens, px, py, freqx, freqy)))
# return rutile, mono
#
#
# def plot(ax, freqx, freqy, intens):
# point_x, point_y = reci_rutile()
# for px, py in zip(point_x, point_y):
# rect = rect_at_point(px, py, "r")
# ax.add_patch(rect)
# ax.text(
# px, py, f"{reduce(extract_rect(intens, px, py, freqx, freqy)):2.2}", clip_on=True
# )
#
# point_x, point_y = reci_mono()
# for px, py in zip(point_x, point_y):
# rect = rect_at_point(px, py, "b")
# ax.add_patch(rect)
# ax.text(
# px, py, f"{reduce(extract_rect(intens, px, py, freqx, freqy)):2.2}", clip_on=True
# )
# ax.imshow(
# intens,
# extent=(np.min(freqx), np.max(freqx), np.min(freqy), np.max(freqy)),
# norm=matplotlib.colors.LogNorm(),
# cmap="Greys"
# )
#
#
# def test_all():
# LEN = 100
# SIZE = 60 * LEN + 1
# quad = np.ones((3, 3))
#
# fig, ax = plt.subplots(1, 3)
#
# lat = VO2_Lattice(LEN, LEN)
# maske = np.ones((LEN, LEN), dtype=bool)
# x, y = lat.get_from_mask(maske)
# img = image_from_pos(x, y)
# img = padding(img, SIZE, SIZE)
# #img = scipy.signal.convolve2d(img, quad)
# img = gaussian(img)
# freqx, freqy, intens_rutile = fft(img)
#
# img = scipy.signal.convolve2d(img, quad)
# ax[0].imshow(img)
#
# maske = np.zeros((LEN, LEN), dtype=bool)
# x, y = lat.get_from_mask(maske)
# img = image_from_pos(x, y)
# img = padding(img, SIZE, SIZE)
# img = gaussian(img)
# freqx, freqy, intens_mono = fft(img)
#
# img = scipy.signal.convolve2d(img, quad)
# ax[2].imshow(img)
#
# maske = np.zeros((LEN, LEN), dtype=bool)
# ind = np.arange(LEN*LEN)
# np.random.shuffle(ind)
# ind = np.unravel_index(ind[:int(LEN*LEN/2)], (LEN, LEN))
# maske[ind] = True
# x, y = lat.get_from_mask(maske)
# img = image_from_pos(x, y)
# img = padding(img, SIZE, SIZE)
# img = gaussian(img)
# freqx, freqy, intens_mono = fft(img)
#
# img = scipy.signal.convolve2d(img, quad)
# ax[1].imshow(img)
#
# print(np.mean(maske))
# x, y = lat.get_from_mask(maske)
# img = image_from_pos(x, y)
# img = padding(img, SIZE, SIZE)
# img = gaussian(img)
# freqx, freqy, intens_50 = fft(img)
#
# fig, axs = plt.subplots(1, 3)
# plot(axs[0], freqx=freqx, freqy=freqy, intens=intens_rutile)
# plot(axs[2], freqx=freqx, freqy=freqy, intens=intens_mono)
# plot(axs[1], freqx=freqx, freqy=freqy, intens=intens_50)
# axs[0].set_title("Rutile")
# axs[2].set_title("Mono")
# axs[1].set_title("50/50")
#
# for ax in axs:
# ax.set_xlim(-1.0, 1.0)
# ax.set_ylim(-1.0, 1.0)
#
#
# def eval(maske, lat, LEN):
# x, y = lat.get_from_mask(maske)
# SIZE = 60 * LEN + 1
# img = image_from_pos(x, y)
# img = padding(img, SIZE, SIZE)
# img = gaussian(img)
# freqx, freqy, intens = fft(img)
# return extract_peaks(freqx, freqy, intens)
#
#
# def reduce(arr):
# arr = np.array(arr)
# arr = arr.flatten()
# return np.sum(arr[np.argpartition(arr, -8)[-8:]])
#
#
# def main():
# LEN = 80
# lat = VO2_Lattice(LEN, LEN)
# maske = np.zeros((LEN, LEN), dtype=bool)
# ind = np.arange(LEN*LEN)
# np.random.shuffle(ind)
# percentage = []
# rutile = []
# monoclinic = []
# counter = 0
# for i in tqdm.tqdm(ind):
# i_unravel = np.unravel_index(i, (LEN, LEN))
# maske[i_unravel] = True
# if np.mod(counter, 300) == 0:
# rut, mono = eval(maske, lat, LEN)
# percentage.append(np.mean(maske))
# rutile.append(reduce(rut))
# monoclinic.append(reduce(mono))
# counter += 1
#
# print(len(percentage), len(mono), len(rutile))
# print(mono)
# plt.figure()
# plt.scatter(percentage, np.array(monoclinic)/monoclinic[0], label="mono")
# plt.scatter(percentage, np.array(rutile)/rutile[0], label="rut")
# plt.legend()
#
#
# if __name__ == "__main__":
# test_all()
# # main()
# plt.show()