FFT/software/fft_sim/spin_image.py
2023-05-08 11:20:51 +02:00

168 lines
6.3 KiB
Python

import logging
from dataclasses import dataclass, field

import numpy as np
import scipy
import scipy.fftpack as sfft
import scipy.signal

from .cache import timeit
from .tools import clean_bounds, clean_bounds_offset
# Module-wide logger; all messages from this file go to the "fft" channel.
logger = logging.getLogger(name="fft")
@dataclass
class FFT:
    """Intensity sampled on a rectilinear frequency grid.

    ``SpinImage.fft`` fills it with ``intens`` of shape
    ``(len(freqx), len(freqy))``. Other producers presumably follow the same
    layout; confirm before relying on it.
    """

    # Each instance gets its own fresh empty array. A plain ``np.array([])``
    # default is rejected by dataclasses on Python 3.11+: ndarray is
    # unhashable and is therefore treated as a mutable default.
    freqx: np.ndarray = field(default_factory=lambda: np.array([]))
    freqy: np.ndarray = field(default_factory=lambda: np.array([]))
    intens: np.ndarray = field(default_factory=lambda: np.array([]))

    def extents(self):
        """Return ``(min(freqx), max(freqx), min(freqy), max(freqy))``.

        This is the (left, right, bottom, top) order that matplotlib's
        ``imshow(extent=...)`` expects.
        """
        return (np.min(self.freqx), np.max(self.freqx),
                np.min(self.freqy), np.max(self.freqy))

    def clean(self):
        """Raise every intensity below ``1e-10`` to ``1e-10``, in place.

        First prints the summed intensity of the clamped entries, the total,
        and their ratio.
        """
        cutoff = 1e-10
        below = self.intens < cutoff
        total = np.sum(self.intens)
        low = np.sum(self.intens[below])
        # An empty or all-zero spectrum would divide 0/0 and raise a
        # RuntimeWarning. The printed ratio is "nan" either way.
        ratio = low / total if total else float("nan")
        print(f"{low}/{total} = {ratio}")
        self.intens[below] = cutoff

    def val2pos(self, x, y):
        """Map frequency values to indices on ``freqx`` / ``freqy``.

        Uses ``np.searchsorted`` (side='left'), so each axis is assumed to be
        sorted ascending. ``SpinImage.fft`` produces fftshift-ed, ascending
        axes.

        Returns:
            ``(xind, yind)`` integer index arrays (or scalars for scalar input).
        """
        xind = np.searchsorted(self.freqx, x).astype(int)
        yind = np.searchsorted(self.freqy, y).astype(int)
        return xind, yind

    def save(self, filename):
        """Write the three arrays to an uncompressed ``.npz`` archive.

        ``np.savez`` appends ``.npz`` to a str/Path filename if it is missing.
        """
        np.savez(filename, intens=self.intens,
                 freqx=self.freqx, freqy=self.freqy)

    def load(self, filename):
        """Replace the three arrays with those stored by :meth:`save`."""
        # np.load on an .npz returns an NpzFile that keeps the file open.
        # Reading a key loads that array into memory, so the archive can be
        # closed as soon as all three arrays have been read.
        with np.load(filename) as files:
            self.freqx = files["freqx"]
            self.freqy = files["freqy"]
            self.intens = files["intens"]
class SpinImage:
    """Rasterise point positions as Gaussian stamps on a 2-D image and FFT it.

    Each phase is an ``(x_pos, y_pos)`` pair of equally shaped coordinate
    arrays. The constructor shifts every phase by the same offset onto one
    pixel grid, on which pixel ``i`` sits at position ``i * resolution``. It
    then precomputes one Gaussian stamp per position. :meth:`apply_mask` adds
    the selected stamps into ``self.img``, and :meth:`fft` returns the image's
    power spectrum.
    """

    resolution = 0.05  # pixel size, in the same units as the positions
    offset = 40  # stamp extent passed to clean_bounds_offset (pixels; presumably a half-width)
    length_x: float
    length_y: float
    sigma: float
    index: list[tuple[np.ndarray, np.ndarray]]  # per phase: pixel indices of each position
    true_pos: list[tuple[np.ndarray, np.ndarray]]  # per phase: shifted positions
    buffer: list[np.ndarray]  # per phase: 1-D object array of per-position stamps
    intens_map: list[np.ndarray]  # per phase: per-position weights, built by apply_mask

    def make_list(self, x_pos, y_pos, x_inds, y_inds):
        """Precompute one Gaussian stamp per (flattened) position.

        Args:
            x_pos, y_pos: shifted positions, in position units.
            x_inds, y_inds: matching pixel indices (see :meth:`_to_index`).

        Returns:
            A 1-D object array. Entry ``k`` is the 2-D stamp for position
            ``k``, evaluated on the pixel window that ``clean_bounds_offset``
            returns around its index.
        """
        grid_x = np.arange(0, self.length_x, self.resolution)
        grid_y = np.arange(0, self.length_y, self.resolution)
        X, Y = np.meshgrid(grid_x, grid_y, indexing="ij")
        stamps = []
        for x, y, xi, yi in zip(x_pos.flatten(), y_pos.flatten(),
                                x_inds.flatten(), y_inds.flatten()):
            xl, yl, xu, yu = clean_bounds_offset(xi, yi, self.offset, X.shape)
            dx = X[xl:xu, yl:yu] - x
            dy = Y[xl:xu, yl:yu] - y
            stamps.append(np.exp(-0.5 * (dx ** 2 + dy ** 2) / self.sigma**2))
        # np.array(stamps, dtype=object) has two problems: it builds an
        # (n, h, w) array of boxed floats when all stamps share a shape, and it
        # raises for some ragged shapes. Fill a 1-D object array explicitly.
        out = np.empty(len(stamps), dtype=object)
        for k, stamp in enumerate(stamps):
            out[k] = stamp
        return out

    def __init__(self, phases: list[tuple[np.ndarray, np.ndarray]], sigma=.1):
        """Place all phases on one pixel grid and precompute their stamps.

        Args:
            phases: ``(x_pos, y_pos)`` pairs of equally shaped coordinate
                arrays, one pair per phase.
            sigma: width of each Gaussian stamp, in position units.
        """
        self.sigma = sigma
        zero_shift_x = np.inf
        zero_shift_y = np.inf
        max_len_x = 0
        max_len_y = 0
        for x_pos, y_pos in phases:
            assert x_pos.shape == y_pos.shape
            zero_shift_x = np.minimum(np.min(x_pos), zero_shift_x)
            zero_shift_y = np.minimum(np.min(y_pos), zero_shift_y)
            max_len_x = np.maximum(np.max(x_pos), max_len_x)
            max_len_y = np.maximum(np.max(y_pos), max_len_y)
        padding = (2 * self.offset + 10) * self.resolution
        # Positions are shifted by -zero_shift below. A negative minimum pushes
        # them beyond max_len, so the image must also cover that shift.
        # Non-negative minima keep the original image size.
        self.length_x = max_len_x - np.minimum(zero_shift_x, 0) + padding
        self.length_y = max_len_y - np.minimum(zero_shift_y, 0) + padding
        self.true_pos = []
        self.index = []
        self.buffer = []
        # Keep a margin of offset + 3 pixels below the smallest position so its
        # stamp window fits inside the image.
        offset_shift = (self.offset + 3) * self.resolution
        for x_pos, y_pos in phases:
            x_pos = x_pos - zero_shift_x + offset_shift
            y_pos = y_pos - zero_shift_y + offset_shift
            x_index, y_index = self._to_index(x_pos, y_pos)
            self.index.append((x_index, y_index))
            self.true_pos.append((x_pos, y_pos))
            self.buffer.append(self.make_list(x_pos, y_pos, x_index, y_index))

    def _to_index(self, pos_x, pos_y):
        """Map positions to pixel indices on the ``resolution`` grid.

        Uses ``np.searchsorted`` (side='left'), so a position ``p`` maps to the
        first grid point ``>= p``.
        """
        grid_x = np.arange(0, self.length_x, self.resolution)
        grid_y = np.arange(0, self.length_y, self.resolution)
        xind = np.searchsorted(grid_x, pos_x).astype(int)
        yind = np.searchsorted(grid_y, pos_y).astype(int)
        return xind, yind

    def _apply_mask(self, idx, buffer, mask):
        """Add the stamps selected by boolean *mask* into ``self.img``.

        Args:
            idx: ``(x_indices, y_indices)`` for one phase.
            buffer: that phase's 1-D stamp array from :meth:`make_list`.
            mask: flat boolean array with one entry per position.
        """
        logger.info("_apply_mask")
        x_indices, y_indices = idx
        assert x_indices.shape == y_indices.shape
        assert x_indices.flatten().shape == mask.shape, \
            f"Invalid mask: {mask.shape} != {x_indices.flatten().shape}"
        for x_idx, y_idx, dat in zip(x_indices.flatten()[mask],
                                     y_indices.flatten()[mask],
                                     buffer[mask]):
            logger.debug("stamp")
            xl, yl, xu, yu = clean_bounds_offset(
                x_idx, y_idx, self.offset, self.img.shape)
            self.img[xl:xu, yl:yu] += dat.astype(np.float64)

    def apply_mask(self, maske):
        """Render ``self.img`` from a label array *maske*.

        Phase ``counter``'s stamp is drawn at every position where
        ``maske == counter``. This resets ``self.img`` and
        ``self.intens_map``, giving one all-ones float map per phase with
        *maske*'s shape.

        Raises:
            AssertionError: if no phase index occurs in *maske*.
        """
        self.img = np.zeros(
            (int(self.length_x / self.resolution),
             int(self.length_y / self.resolution)))
        checker = False
        self.intens_map = []
        for counter, (idx, buf) in enumerate(zip(self.index, self.buffer)):
            self._apply_mask(idx, buf, (maske == counter).flatten())
            if counter in maske:
                checker = True
            # Use float, not maske's dtype: an integer label array would give
            # an integer map, and gaussian()'s in-place `*=` by a float
            # envelope would then fail to cast.
            self.intens_map.append(np.ones_like(maske, dtype=np.float64))
        assert checker, "No stamp set"

    def fft(self):
        """Return the centred power spectrum ``|FFT2(img)|**2`` as an :class:`FFT`.

        Frequencies come from ``fftfreq`` with spacing ``resolution``. They
        are fftshift-ed, like the spectrum, so both axes ascend.
        """
        spectrum = sfft.fftshift(sfft.fft2(self.img))
        freqx = sfft.fftshift(sfft.fftfreq(self.img.shape[0], self.resolution))
        freqy = sfft.fftshift(sfft.fftfreq(self.img.shape[1], self.resolution))
        return FFT(freqx=freqx, freqy=freqy, intens=np.abs(spectrum) ** 2)

    def gaussian(self, sigma):
        """Multiply the image and every intensity map by a centred 2-D Gaussian.

        The envelope is normalised by ``1 / (2*pi*sigma**2)`` and centred at
        ``(length_x / 2, length_y / 2)``. Requires :meth:`apply_mask` first.
        """
        n_x, n_y = self.img.shape
        # Pixel i sits at position i * resolution, the same grid that
        # _to_index and make_list use, so image and maps see one envelope.
        X, Y = np.meshgrid(np.arange(n_x) * self.resolution,
                           np.arange(n_y) * self.resolution, indexing="ij")
        mu_x = self.length_x / 2.
        mu_y = self.length_y / 2.
        norm = 1 / (2 * np.pi * sigma * sigma)

        def envelope(px, py):
            return norm * np.exp(-((px - mu_x) ** 2 + (py - mu_y) ** 2)
                                 / (2 * sigma**2))

        self.img = np.multiply(self.img, envelope(X, Y))
        for (px, py), weights in zip(self.true_pos, self.intens_map):
            weights *= envelope(px, py)

    def get_intens(self, mask):
        """Return, per phase, the summed ``intens_map`` weight where ``mask == phase``."""
        return [np.sum(weights[mask == counter])
                for counter, weights in enumerate(self.intens_map)]