from dataclasses import dataclass, field
from tools import clean_bounds, clean_bounds_offset
from cache import timeit
import scipy.signal
import scipy.fftpack as sfft
import scipy
import numpy as np
import logging

logger = logging.getLogger("fft")


@dataclass
class FFT:
    """Container for a 2D power spectrum and its two frequency axes."""

    # NOTE: ndarray defaults must go through default_factory; an ndarray is
    # unhashable and Python >= 3.11 rejects unhashable dataclass defaults.
    # The factory also gives every instance its own array.
    freqx: np.ndarray = field(default_factory=lambda: np.array([]))
    freqy: np.ndarray = field(default_factory=lambda: np.array([]))
    intens: np.ndarray = field(default_factory=lambda: np.array([]))

    def extents(self):
        """Return ``(min(freqx), max(freqx), min(freqy), max(freqy))``."""
        return (
            np.min(self.freqx),
            np.max(self.freqx),
            np.min(self.freqy),
            np.max(self.freqy),
        )

    def clean(self):
        """Raise every intensity below ``1e-10`` to that cutoff, in place.

        Before clamping, prints the summed intensity of the affected entries
        next to the total intensity and their ratio.
        """
        cutoff = 1e-10
        below = self.intens < cutoff
        total = np.sum(self.intens)
        low = np.sum(self.intens[below])
        print(f"{low}/{total} = {low/total}")
        self.intens[below] = cutoff

    def val2pos(self, x, y):
        """Return the ``searchsorted`` indices of ``x`` in ``freqx`` and
        ``y`` in ``freqy`` as integers."""
        xind = np.searchsorted(self.freqx, x).astype(int)
        yind = np.searchsorted(self.freqy, y).astype(int)
        return xind, yind

    def save(self, filename):
        """Write the three arrays to ``filename`` with ``np.savez``.

        Note that ``np.savez`` appends ``.npz`` to a string/path name that
        lacks it, so pass the full name (with extension) to :meth:`load`.
        """
        np.savez(filename, intens=self.intens,
                 freqx=self.freqx, freqy=self.freqy)

    def load(self, filename):
        """Replace the three arrays with those stored in ``filename``."""
        # The archive keeps a file handle open; close it once the arrays
        # have been read into memory.
        with np.load(filename) as files:
            self.freqx = files["freqx"]
            self.freqy = files["freqy"]
            self.intens = files["intens"]


class SpinImage:
    """Rasterise point positions as Gaussian blobs on a regular grid.

    Each "phase" is a pair ``(x_pos, y_pos)`` of equally shaped coordinate
    arrays. On construction, one Gaussian patch per point is precomputed for
    every phase. :meth:`apply_mask` then stamps, for every point, the patch
    of the phase selected by an integer mask into ``self.img``, which
    :meth:`fft` transforms into a power spectrum.
    """

    resolution = 0.05   # grid spacing, in the units of the input positions
    offset = 40         # window half-size passed to clean_bounds_offset
    length_x: float     # grid extent along x (positions units)
    length_y: float     # grid extent along y (positions units)
    sigma: float        # width of the per-point Gaussian
    index: list[tuple[np.ndarray, np.ndarray]]     # grid indices per phase
    true_pos: list[tuple[np.ndarray, np.ndarray]]  # shifted positions per phase
    buffer: list[np.ndarray]      # precomputed patches per phase
    intens_map: list[np.ndarray]  # per-phase weights, set by apply_mask
    img: np.ndarray               # rendered image, set by apply_mask

    def _grid_axes(self):
        """Return the x and y coordinate axes of the rendering grid."""
        axis_x = np.arange(0, self.length_x, self.resolution)
        axis_y = np.arange(0, self.length_y, self.resolution)
        return axis_x, axis_y

    def make_list(self, x_pos, y_pos, x_inds, y_inds):
        """Precompute one Gaussian patch per point.

        Args:
            x_pos, y_pos: point coordinates on the grid (already shifted).
            x_inds, y_inds: grid indices of those points (see ``_to_index``).

        Returns:
            Object array holding, for each point in flattened order, the
            values ``exp(-0.5 * r**2 / sigma**2)`` on the window returned by
            ``clean_bounds_offset`` (presumably the index clipped to
            ``+/- offset`` cells inside the grid -- defined in ``tools``).
        """
        axis_x, axis_y = self._grid_axes()
        X, Y = np.meshgrid(axis_x, axis_y, indexing="ij")
        patches = []
        for x, y, ix, iy in zip(
            x_pos.flatten(),
            y_pos.flatten(),
            x_inds.flatten(),
            y_inds.flatten(),
        ):
            xl, yl, xu, yu = clean_bounds_offset(
                ix, iy, self.offset, X.shape)
            dist_sq = (X[xl:xu, yl:yu] - x) ** 2 + (Y[xl:xu, yl:yu] - y) ** 2
            patches.append(np.exp(-0.5 * dist_sq / self.sigma**2))
        return np.array(patches, dtype=object)

    def __init__(self, phases: list[tuple[np.ndarray, np.ndarray]], sigma=.1):
        """Set up the grid and precompute the patches of every phase.

        Args:
            phases: list of ``(x_pos, y_pos)`` pairs with matching shapes.
            sigma: width of the Gaussian drawn for each point.
        """
        self.sigma = sigma
        zero_shift_x = 10000000
        zero_shift_y = 10000000
        max_len_x = 0
        max_len_y = 0
        for x_pos, y_pos in phases:
            assert x_pos.shape == y_pos.shape
            zero_shift_x = np.minimum(np.min(x_pos), zero_shift_x)
            zero_shift_y = np.minimum(np.min(y_pos), zero_shift_y)
            max_len_x = np.maximum(np.max(x_pos), max_len_x)
            max_len_y = np.maximum(np.max(y_pos), max_len_y)

        padding = (2 * self.offset + 10) * self.resolution
        # Positions are shifted by -zero_shift below, so a negative minimum
        # pushes points past max_len; grow the grid by that amount so they
        # still land inside it. Non-negative inputs are unaffected.
        self.length_x = max_len_x - np.minimum(zero_shift_x, 0) + padding
        self.length_y = max_len_y - np.minimum(zero_shift_y, 0) + padding

        self.true_pos = []
        self.index = []
        self.buffer = []
        # Keep every point at least offset + 3 cells away from the low edge.
        offset_shift = (self.offset + 3) * self.resolution
        for x_pos, y_pos in phases:
            x_pos = x_pos - zero_shift_x + offset_shift
            y_pos = y_pos - zero_shift_y + offset_shift
            x_index, y_index = self._to_index(x_pos, y_pos)
            self.index.append((x_index, y_index))
            self.true_pos.append((x_pos, y_pos))
            self.buffer.append(
                self.make_list(x_pos, y_pos, x_index, y_index))

    def _to_index(self, pos_x, pos_y):
        """Map coordinates to integer grid indices via ``searchsorted``."""
        axis_x, axis_y = self._grid_axes()
        xind = np.searchsorted(axis_x, pos_x).astype(int)
        yind = np.searchsorted(axis_y, pos_y).astype(int)
        return xind, yind

    def _apply_mask(self, idx, buffer, mask):
        """Add the patches of one phase selected by ``mask`` to ``self.img``.

        Args:
            idx: ``(x_indices, y_indices)`` of the phase.
            buffer: patches of the phase, as returned by ``make_list``.
            mask: flat boolean array, one entry per point.
        """
        logger.info("_apply_mask")
        (x_indices, y_indices) = idx
        assert x_indices.shape == y_indices.shape
        assert x_indices.flatten().shape == mask.shape,\
            f"Invalid mask: {mask.shape} != {x_indices.flatten().shape}"
        for x_idx, y_idx, dat in zip(
            x_indices.flatten()[mask],
            y_indices.flatten()[mask],
            buffer[mask]
        ):
            logger.debug("stamp")
            xl, yl, xu, yu = clean_bounds_offset(
                x_idx, y_idx, self.offset, self.img.shape)
            # Patches live in an object array; convert before accumulating.
            self.img[xl:xu, yl:yu] += dat.astype(np.float64)

    def apply_mask(self, maske):
        """Render ``self.img`` from scratch according to ``maske``.

        ``maske`` holds, per point, the number of the phase whose patch is
        drawn there. Also resets ``self.intens_map`` to one array of ones
        (shaped like ``maske``) per phase.

        Raises:
            AssertionError: if no phase number occurs in ``maske``.
        """
        self.img = np.zeros(
            (int(self.length_x / self.resolution),
             int(self.length_y / self.resolution)))
        checker = False
        self.intens_map = []
        for counter, (idx, buf) in enumerate(zip(self.index, self.buffer)):
            self._apply_mask(idx, buf, (maske == counter).flatten())
            if counter in maske:
                checker = True
            # One entry per phase (not only per used phase): get_intens and
            # gaussian address intens_map by phase number.
            self.intens_map.append(np.ones_like(maske))
        assert checker, "No stamp set"

    def fft(self):
        """Return the shifted power spectrum of ``self.img`` as an FFT."""
        spectrum = sfft.fftshift(sfft.fft2(self.img))
        freqx = sfft.fftshift(
            sfft.fftfreq(self.img.shape[0], self.resolution))
        freqy = sfft.fftshift(
            sfft.fftfreq(self.img.shape[1], self.resolution))
        return FFT(freqx=freqx, freqy=freqy, intens=np.abs(spectrum) ** 2)

    def gaussian(self, sigma):
        """Apply a Gaussian window centred on the middle of the grid.

        Multiplies ``self.img`` by a normalised 2D Gaussian of width
        ``sigma`` and weights every ``intens_map`` entry by the same
        Gaussian evaluated at the points' positions. Requires a previous
        :meth:`apply_mask` call.
        """
        # x runs along length_x and y along length_y, so the centres must
        # be taken from the matching extent.
        mu_x = self.length_x / 2.
        mu_y = self.length_y / 2.

        def window(px, py):
            return 1 / (2 * np.pi * sigma * sigma) * \
                np.exp(-((px - mu_x)**2 / (2 * sigma**2)
                         + (py - mu_y)**2 / (2 * sigma**2)))

        x = np.linspace(0, self.length_x, self.img.shape[0])
        y = np.linspace(0, self.length_y, self.img.shape[1])
        X, Y = np.meshgrid(x, y, indexing="ij")
        self.img = np.multiply(self.img, window(X, Y))

        for phase, ((pos_x, pos_y), weights) in enumerate(
                zip(self.true_pos, self.intens_map)):
            # Rebind instead of multiplying in place: the weights start as
            # ones_like(mask) and an integer array cannot absorb a float
            # product in place.
            self.intens_map[phase] = weights * window(pos_x, pos_y)

    def get_intens(self, mask):
        """Return, per phase, the summed weights of the points that
        ``mask`` assigns to that phase."""
        return [
            np.sum(weights[mask == phase])
            for phase, weights in enumerate(self.intens_map)
        ]