168 lines
6.3 KiB
Python
168 lines
6.3 KiB
Python
import logging
from dataclasses import dataclass, field

import numpy as np
import scipy
import scipy.fftpack as sfft
import scipy.signal

from .cache import timeit
from .tools import clean_bounds, clean_bounds_offset
|
|
logger = logging.getLogger("fft")
|
|
|
|
|
|
@dataclass
class FFT:
    """Container for a 2-D intensity spectrum and its two frequency axes.

    Every field defaults to its own freshly created empty array. A plain
    ``np.array([])`` default would be a single object shared by all
    instances, and Python >= 3.11 rejects such unhashable dataclass defaults
    with ``ValueError`` when the class is defined.
    """

    # Frequency values belonging to the first axis of ``intens``.
    freqx: np.ndarray = field(default_factory=lambda: np.array([]))
    # Frequency values belonging to the second axis of ``intens``.
    freqy: np.ndarray = field(default_factory=lambda: np.array([]))
    # Intensity values of the spectrum.
    intens: np.ndarray = field(default_factory=lambda: np.array([]))

    def extents(self):
        """Return ``(min(freqx), max(freqx), min(freqy), max(freqy))``."""
        return (
            np.min(self.freqx),
            np.max(self.freqx),
            np.min(self.freqy),
            np.max(self.freqy),
        )

    def clean(self):
        """Raise every intensity below ``1e-10`` to that floor, in place.

        Before clamping, prints the summed intensity below the cutoff, the
        total intensity and the ratio of the two.
        """
        cutoff = 1e-10
        below = self.intens < cutoff
        total = np.sum(self.intens)
        low = np.sum(self.intens[below])
        print(f"{low}/{total} = {low/total}")
        self.intens[below] = cutoff

    def val2pos(self, x, y):
        """Map frequency values to indices into ``freqx`` / ``freqy``.

        Uses ``np.searchsorted``, so each result is the index of the first
        axis entry that is not smaller than the requested value; both axes
        must therefore be sorted in ascending order.

        Returns:
            Tuple ``(xind, yind)`` of integer indices.
        """
        xind = np.searchsorted(self.freqx, x).astype(int)
        yind = np.searchsorted(self.freqy, y).astype(int)
        return xind, yind

    def save(self, filename):
        """Write ``intens``, ``freqx`` and ``freqy`` to an ``.npz`` archive.

        ``np.savez`` appends ``.npz`` to *filename* when that suffix is
        missing, so pass the full name if the same string is later given to
        :meth:`load`.
        """
        np.savez(filename, intens=self.intens,
                 freqx=self.freqx, freqy=self.freqy)

    def load(self, filename):
        """Replace all three arrays with the contents of an ``.npz`` archive.

        The archive must contain the keys ``freqx``, ``freqy`` and ``intens``.
        """
        # NpzFile keeps the underlying file open until it is closed; the
        # context manager releases it once the arrays have been read.
        with np.load(filename) as files:
            self.freqx = files["freqx"]
            self.freqy = files["freqy"]
            self.intens = files["intens"]
|
|
|
|
|
|
class SpinImage:
    """Raster image assembled from Gaussian stamps placed at given positions.

    ``phases`` is a list of ``(x_pos, y_pos)`` array pairs. The constructor
    shifts all positions into a common padded frame and pre-computes one
    Gaussian stamp per position. :meth:`apply_mask` then adds the stamps
    selected by an integer mask into ``self.img``, which :meth:`fft`
    transforms into a power spectrum.
    """

    # Grid spacing of the raster, in the same units as the positions.
    resolution = 0.05
    # Passed to clean_bounds_offset for every stamp; presumably the stamp
    # half-width in grid cells -- confirm against .tools.
    offset = 40
    length_x: float
    length_y: float
    sigma: float
    # Per phase: grid indices of every position.
    index: list[tuple[np.ndarray, np.ndarray]]
    # Per phase: object array holding the pre-computed stamps.
    buffer: list[np.ndarray]
    # Per phase: weight map, (re)created by apply_mask.
    intens_map: list[np.ndarray]

    def _grid_axes(self):
        """Return the x and y coordinate axes of the raster."""
        return (
            np.arange(0, self.length_x, self.resolution),
            np.arange(0, self.length_y, self.resolution),
        )

    def make_list(self, x_pos, y_pos, x_inds, y_inds):
        """Pre-compute one Gaussian stamp per position.

        Args:
            x_pos, y_pos: Positions in the (shifted) grid frame.
            x_inds, y_inds: Grid indices belonging to those positions.

        Returns:
            Object-dtype array built from the list of stamps. Each stamp is
            ``exp(-0.5 * r**2 / sigma**2)`` evaluated on the grid window that
            ``clean_bounds_offset`` returns for the position's index.
        """
        grid_x, grid_y = self._grid_axes()
        X, Y = np.meshgrid(grid_x, grid_y, indexing="ij")
        out_list = []
        # Distinct loop names: the original reused x_ind / y_ind and thereby
        # shadowed the grid axes.
        for x, y, ix, iy in zip(
            x_pos.flatten(),
            y_pos.flatten(),
            x_inds.flatten(),
            y_inds.flatten(),
        ):
            xl, yl, xu, yu = clean_bounds_offset(
                ix, iy, self.offset, X.shape)
            dist_sq = (X[xl:xu, yl:yu] - x) ** 2 + (Y[xl:xu, yl:yu] - y) ** 2
            out_list.append(np.exp(-0.5 * dist_sq / self.sigma**2))
        return np.array(out_list, dtype=object)

    def __init__(self, phases: list[tuple[np.ndarray, np.ndarray]], sigma=.1):
        """Shift all phases into one padded frame and build their stamps.

        Args:
            phases: List of ``(x_pos, y_pos)`` pairs of equally shaped arrays.
            sigma: Width of the Gaussian stamps.
        """
        self.sigma = sigma
        # +inf is the neutral element of the minimum search; a large finite
        # sentinel would silently win for data lying above it.
        zero_shift_x = np.inf
        zero_shift_y = np.inf
        max_len_x = 0
        max_len_y = 0
        for x_pos, y_pos in phases:
            assert x_pos.shape == y_pos.shape
            zero_shift_x = np.minimum(np.min(x_pos), zero_shift_x)
            zero_shift_y = np.minimum(np.min(y_pos), zero_shift_y)
            max_len_x = np.maximum(np.max(x_pos), max_len_x)
            max_len_y = np.maximum(np.max(y_pos), max_len_y)
        padding = (2 * self.offset + 10) * self.resolution
        # Positions are shifted by -zero_shift below. With a negative minimum
        # the shifted maximum is max - min, which can exceed max + padding, so
        # the grid has to grow by |min|. Non-negative minima keep the historic
        # grid size unchanged.
        self.length_x = max_len_x - np.minimum(zero_shift_x, 0) + padding
        self.length_y = max_len_y - np.minimum(zero_shift_y, 0) + padding
        self.true_pos = []
        self.index = []
        self.buffer = []
        offset_shift = (self.offset + 3) * self.resolution
        for x_pos, y_pos in phases:
            x_pos = x_pos - zero_shift_x + offset_shift
            y_pos = y_pos - zero_shift_y + offset_shift
            x_index, y_index = self._to_index(x_pos, y_pos)
            self.index.append((x_index, y_index))
            self.true_pos.append((x_pos, y_pos))
            self.buffer.append(self.make_list(x_pos, y_pos, x_index, y_index))

    def _to_index(self, pos_x, pos_y):
        """Convert positions to grid indices via ``np.searchsorted``.

        Each index points at the first grid coordinate that is not smaller
        than the position.
        """
        grid_x, grid_y = self._grid_axes()
        xind = np.searchsorted(grid_x, pos_x).astype(int)
        yind = np.searchsorted(grid_y, pos_y).astype(int)
        return xind, yind

    def _apply_mask(self, idx, buffer, mask):
        """Add the stamps selected by *mask* into ``self.img``.

        Args:
            idx: ``(x_indices, y_indices)`` of one phase.
            buffer: Stamp array of the same phase.
            mask: Flat boolean selector with one entry per position.
        """
        logger.info("_apply_mask")
        (x_indices, y_indices) = idx
        assert x_indices.shape == y_indices.shape
        assert x_indices.flatten().shape == mask.shape,\
            f"Invalid mask: {mask.shape} != {x_indices.flatten().shape}"
        for x_idx, y_idx, dat in zip(
            x_indices.flatten()[mask],
            y_indices.flatten()[mask],
            buffer[mask]
        ):
            logger.debug("stamp")
            xl, yl, xu, yu = clean_bounds_offset(
                x_idx, y_idx, self.offset, self.img.shape)
            # Stamps are stored with object dtype; convert before adding.
            self.img[xl: xu, yl: yu] += dat.astype(np.float64)

    def apply_mask(self, maske):
        """Rebuild ``self.img`` from the stamps that *maske* selects.

        Position ``i`` contributes the stamp of phase ``k`` when the flattened
        mask holds ``k`` at ``i``. Also resets ``intens_map`` to one all-ones
        weight map per phase.

        Raises:
            AssertionError: If no phase number occurs in *maske*.
        """
        # NOTE(review): the image is sized with int(length / resolution) while
        # the stamps are cut from an np.arange grid; the two sizes may differ
        # by one cell -- confirm this never matters at the border.
        self.img = np.zeros(
            (int(self.length_x/self.resolution), int(self.length_y/self.resolution)))
        checker = False
        self.intens_map = []
        for counter, (idx, buf) in enumerate(zip(self.index, self.buffer)):
            self._apply_mask(idx, buf, (maske == counter).flatten())
            if counter in maske:
                checker = True
            # Float weights: ones_like would otherwise inherit the integer
            # dtype of the mask and could not hold Gaussian weights later.
            self.intens_map.append(np.ones_like(maske, dtype=np.float64))
        assert checker, "No stamp set"

    def fft(self):
        """Return the centred power spectrum of ``self.img`` as an ``FFT``.

        Both the spectrum and the frequency axes are passed through
        ``fftshift``; the intensity is ``abs(fft2(img)) ** 2``.
        """
        Z_fft = sfft.fft2(self.img)
        Z_shift = sfft.fftshift(Z_fft)
        fft_freqx = sfft.fftfreq(self.img.shape[0], self.resolution)
        fft_freqy = sfft.fftfreq(self.img.shape[1], self.resolution)
        fft_freqx_clean = sfft.fftshift(fft_freqx)
        fft_freqy_clean = sfft.fftshift(fft_freqy)
        return FFT(freqx=fft_freqx_clean, freqy=fft_freqy_clean, intens=np.abs(Z_shift) ** 2)

    def gaussian(self, sigma):
        """Multiply image and weight maps by a Gaussian centred on the image.

        The envelope is ``1 / (2 pi sigma**2) * exp(-r**2 / (2 sigma**2))``
        with ``r`` measured from ``(length_x / 2, length_y / 2)``. The image is
        weighted on its pixel grid, every entry of ``intens_map`` at the
        stored positions of its phase.
        """
        x = np.linspace(0, self.length_x, self.img.shape[0])
        y = np.linspace(0, self.length_y, self.img.shape[1])
        X, Y = np.meshgrid(x, y, indexing="ij")
        # X runs along length_x, Y along length_y; the centres were swapped
        # before, which displaced the envelope on non-square images.
        mu_x = self.length_x / 2.
        mu_y = self.length_y / 2.

        def envelope(px, py):
            return 1 / (2 * np.pi * sigma * sigma) * \
                np.exp(-((px - mu_x)**2 / (2 * sigma**2) +
                         (py - mu_y)**2 / (2 * sigma**2)))

        self.img = np.multiply(self.img, envelope(X, Y))
        for i, ((px, py), weights) in enumerate(
                zip(self.true_pos, self.intens_map)):
            # Rebind instead of ``*=`` so that integer weight maps are
            # promoted to float rather than raising a casting error.
            self.intens_map[i] = weights * envelope(px, py)

    def get_intens(self, mask):
        """Return, per phase, the summed weights at positions *mask* assigns to it.

        Entry ``k`` of the result is ``intens_map[k][mask == k].sum()``.
        """
        return [
            np.sum(data[mask == counter])
            for counter, data in enumerate(self.intens_map)
        ]
|