FFT/software/extractors.py
2023-05-03 09:48:15 +02:00

106 lines
3.2 KiB
Python

import logging
from abc import abstractmethod, ABC
from tools import clean_bounds_offset
from spin_image import FFT
from cache import persist_to_file, timeit
import tqdm
import numpy as np
import matplotlib.pyplot as plt
# from scipy.spatial import Voronoi
# import cv2
# Shared logger for the FFT evaluation code in this module.
logger = logging.getLogger(name="fft")
class Evaluator(ABC):
    """Base class for extracting summed intensities from labelled regions.

    A subclass builds a label mask for an image (``gen_mask_helper``); the
    label ``-1`` marks pixels that belong to no evaluation region.
    ``eval_points`` groups labels together: ``extract`` sums
    ``img.intens`` per label, one group at a time.
    """

    # One entry per group; each entry is an iterable of mask labels.
    eval_points: list[np.ndarray]

    def extract(self, img: FFT):
        """Sum ``img.intens`` over each mask label, grouped by ``eval_points``.

        Labels that do not occur in the mask are skipped.  A final group
        holding only the background label ``-1`` is always appended.

        Args:
            img: image whose ``intens`` array is indexed with the mask
                (assumed to have the mask's shape).

        Returns:
            ``(all_idx, all_val)``: parallel lists of lists with the labels
            that were found and their summed intensities; both end with the
            background entry (``[-1]`` and its sum).
        """
        all_val = []
        all_idx = []
        for ev_points in self.eval_points:
            logger.info("Extracting points: %s", ev_points)
            temp_val = []
            temp_idx = []
            for num in ev_points:
                # Build the selection once instead of re-comparing the
                # whole mask for the emptiness test and for the sum.
                selected = self.mask == num
                if not selected.any():
                    continue
                temp_val.append(np.sum(img.intens[selected]))
                temp_idx.append(num)
            all_val.append(temp_val)
            all_idx.append(temp_idx)
        # Background: every pixel outside all evaluation regions.
        all_val.append([np.sum(img.intens[self.mask == -1])])
        all_idx.append([-1])
        return all_idx, all_val

    def debug(self, img: FFT):
        """Offset ``img.intens`` in place by each label's group index.

        Pixels of the ``count``-th group in ``eval_points`` get ``count``
        added, so the first group (index 0) is left unchanged.
        NOTE(review): that makes group 0 indistinguishable from unlabelled
        pixels — confirm this is intended.

        Returns:
            The same ``img`` object, modified in place.
        """
        for count, ev_points in enumerate(self.eval_points, start=0):
            for num in ev_points:
                img.intens[self.mask == num] += count
        return img

    def get_mask(self):
        """Return the label mask last built by ``generate_mask``."""
        return self.mask

    @timeit
    def generate_mask(self, img: FFT, merge=False):
        """Build ``self.mask`` for ``img`` using the subclass helper.

        Args:
            img: image the mask is built for.
            merge: if true, replace the mask with the subclass's merged
                version (``merge_mask_helper``) and reset ``eval_points``
                to one single-label group per former group:
                ``[[0], [1], ...]``.
        """
        self.mask = self.gen_mask_helper(img)
        if merge:
            self.mask = self.merge_mask_helper()
            self.eval_points = [[a] for a in np.arange(len(self.eval_points))]

    @abstractmethod
    def gen_mask_helper(self, img: FFT):
        """Return a label mask for ``img`` (label ``-1`` = background).

        Called by ``generate_mask`` as ``gen_mask_helper(img)``.
        """

    @abstractmethod
    def merge_mask_helper(self):
        """Return a copy of ``self.mask`` with each group relabelled.

        After this call ``generate_mask`` resets ``eval_points`` to
        ``[[0], [1], ...]``, so the ``i``-th group's labels are expected to
        become ``i``.  Called by ``generate_mask`` as
        ``merge_mask_helper()``.
        """

    def purge(self, img: FFT):
        """Zero ``img.intens`` in place wherever the mask label is not ``-1``."""
        img.intens[self.mask != -1] = 0
class Rect_Evaluator(Evaluator):
    """Evaluator whose regions are rectangular windows around spot positions.

    Every spot gets its own mask label; labels are numbered consecutively
    across all spot groups, and ``eval_points`` records which contiguous
    range of labels belongs to which group.
    """

    def __init__(self, spots: list[tuple[np.ndarray, np.ndarray]], length: int = 6):
        """
        Args:
            spots: one ``(x, y)`` pair of coordinate arrays per spot group;
                the coordinates are converted to pixel indices with
                ``img.val2pos`` when the mask is built.
            length: window size parameter forwarded to
                ``clean_bounds_offset`` (presumably a half-width or width in
                pixels — confirm against that helper).
        """
        self.spots = spots
        self.length = length
        self.eval_points = []
        start = 0
        for sp in spots:
            n_spots = sp[0].size
            self.eval_points.append(np.arange(start, start + n_spots))
            start += n_spots

    def merge_mask_helper(self, *, plot: bool = False):
        """Return a copy of the mask with each spot group collapsed to one label.

        All labels of the ``i``-th group in ``eval_points`` become ``i``;
        background pixels (``-1``) are left untouched.  A group is matched
        by its label range ``[min, max]``, which is exact because
        ``__init__`` numbers each group's labels contiguously.  Empty groups
        own no labels and are skipped.

        Args:
            plot: if true, draw the merged and the original mask in two new
                matplotlib figures (formerly done unconditionally on every
                call, which piled up open figures in batch runs).
        """
        merged = self.mask.copy()
        for group_idx, ev_points in enumerate(self.eval_points):
            if np.size(ev_points) == 0:
                # np.min/np.max raise ValueError on an empty group.
                continue
            # Match against the *original* mask so pixels that were already
            # relabelled are never picked up again by a later group.
            in_group = np.logical_and(
                self.mask >= np.min(ev_points),
                self.mask <= np.max(ev_points),
            )
            merged[in_group] = group_idx
        if plot:
            plt.figure()
            plt.imshow(merged)
            plt.figure()
            plt.imshow(self.mask)
        return merged

    def gen_mask_helper(self, img: FFT):
        """Return an integer label mask with one rectangular window per spot.

        Pixels outside every window are ``-1``.  Spots are labelled
        ``0, 1, 2, ...`` in the order they appear in ``self.spots``; where
        windows overlap, the later spot's label wins.
        """
        # Labels need an integer dtype of their own: inheriting the dtype of
        # img.intens (as np.full_like did) breaks the -1 sentinel for
        # unsigned data and can lose label precision for narrow types.
        mask = np.full(img.intens.shape, -1, dtype=int)
        label = 0
        for spot_group in self.spots:
            logger.debug("Spot: %s", spot_group)
            (x, y) = spot_group
            x, y = img.val2pos(x, y)
            for x_p, y_p in zip(x, y):
                xl, yl, xu, yu = clean_bounds_offset(
                    x_p, y_p, self.length, img.intens.shape)
                mask[xl:xu, yl:yu] = label
                label += 1
        return mask