Clean reimplementation added

This commit is contained in:
Jacob Holder 2023-04-19 15:36:30 +02:00
parent 5ed55ed759
commit 5e6c84ebf5
Signed by: jacob
GPG Key ID: 2194FC747048A7FD
10 changed files with 999 additions and 0 deletions

95
clean_python/analysis.py Normal file
View File

@ -0,0 +1,95 @@
import sys
import numpy as np
import matplotlib.pyplot as plt
import glob
import scipy.interpolate as ip
plt.style.use(["style", "colors", "two_column"])
def check_percentage(p1, p2):
plt.figure()
plt.plot(p1, p2)
def merge(files):
merge = []
for file in files:
data = np.load(file, allow_pickle=True)
old_percentage = data["percentage"]
w_percentage = data["w_percentage"]
# check_percentage(old_percentage, w_percentage)
percentage = w_percentage
out = []
for o in ["out_1", "out_2", "out_3", "out_4"]:
out.append(np.array(data[o]))
out = np.array(out)[:, :, 0]
summe = np.max(np.sum(out, axis=0))
out = out / summe
merge.append(out)
print(merge)
merge = sum(merge)
summe = np.max(np.sum(merge, axis=0))
merge = merge / summe
print(merge)
return percentage, merge
def debug(percentage, out):
plt.figure()
for o in out:
plt.plot(percentage, o)
def stacked_plot(percentage, out, title=""):
plt.figure()
stacks = plt.stackplot(percentage, out[[0, 3, 1, 2], :], colors=[
"w"], ls="solid", ec="k")
hatches = ["/", "", "\\", "\\"]
for stack, hatch in zip(stacks, hatches):
stack.set_hatch(hatch)
plt.xlabel("Insulating Phase (%)")
plt.ylabel("normalized Intensity ")
plt.ylim([0.4, 1])
plt.xlim([0., 1])
plt.tight_layout()
plt.text(0.1, 0.9, "monoclinic", backgroundcolor="w")
plt.text(0.6, 0.5, "rutile", backgroundcolor="w")
plt.text(0.35, 0.75, "diffusive", backgroundcolor="w")
plt.title(title)
def time_scale(p, o):
rut_perc = o[0]
rut_perc = rut_perc - np.min(rut_perc)
rut_perc /= np.max(rut_perc)
mono_perc = -o[2]
mono_perc = mono_perc - np.min(mono_perc)
mono_perc /= np.max(mono_perc)
cs_rut = ip.CubicSpline(p[::-1], rut_perc[::-1])
cs_mono = ip.CubicSpline(p[::-1], mono_perc[::-1])
plt.figure()
ph = np.linspace(0, 1, 100)
plt.plot(ph, cs_rut(ph))
plt.plot(ph, cs_mono(ph))
time = np.linspace(0, 3, 1000)
phy_phase = np.exp(-time)
rut_phase = cs_rut(phy_phase)
mono_phase = cs_mono(phy_phase)
plt.figure()
plt.plot(time, phy_phase)
plt.plot(time, rut_phase)
plt.plot(time, mono_phase)
if __name__ == "__main__":
p, o = merge(sys.argv[1:])
# eval_data_print(f)
stacked_plot(p, o)
# debug(p, o)
time_scale(p, o)
plt.show()

69
clean_python/cache.py Normal file
View File

@ -0,0 +1,69 @@
import numpy as np
import logging
from functools import wraps
import time
def timeit(func):
@wraps(func)
def timeit_wrapper(*args, **kwargs):
start_time = time.perf_counter()
logging.info(f"Start Function {func.__name__}:")
result = func(*args, **kwargs)
end_time = time.perf_counter()
total_time = end_time - start_time
# first item in the args, ie `args[0]` is `self`
logging.info(
f'Function {func.__name__} Took {total_time:.4f} seconds')
return result
return timeit_wrapper
def persist_to_file(file_name):
def decorator(original_func):
try:
file_nam = file_name
if file_name[-4:] != ".npz":
file_nam += ".npz"
file = np.load(file_nam)
cache = dict(zip((file.files), (file[k] for k in file.files)))
except (IOError, ValueError):
cache = {}
def hash_func(*param):
key = str(param)
return key
def persist_func(*param, hash=None):
if hash is None:
hash = hash_func(*param)
print("Hash: ", hash)
if cache == {} or ("hash" not in cache) or hash != cache["hash"]:
print("recalc")
data = original_func(*param)
np.savez(file_name, hash=hash, dat=data)
cache["hash"] = hash
cache["dat"] = data
print("loaded")
return cache["dat"]
return persist_func
return decorator
# test = 1234
#
#
# @persist_to_file("cache.npz", test)
# def test():
# print("calculate")
# return np.zeros((100, 100))
#
#
# if __name__ == "__main__":
# test()
# test()
# pass

161
clean_python/extractors.py Normal file
View File

@ -0,0 +1,161 @@
import numpy as np
import tqdm
# from scipy.spatial import Voronoi
# import cv2
from cache import persist_to_file, timeit
from spin_image import FFT
from tools import clean_bounds_offset
from abc import abstractmethod, ABC
import logging
logger = logging.getLogger("fft")
class Evaluator(ABC):
eval_points: list[np.ndarray]
def extract(self, img: FFT):
all_val = []
all_idx = []
for ev_points in self.eval_points:
logger.info(f"Extracting points: {ev_points}")
temp_val = []
temp_idx = []
for num in ev_points:
if np.sum(self.mask == num) == 0:
continue
temp_val.append(np.sum(img.intens[self.mask == num]))
temp_idx.append(num)
all_val.append(temp_val)
all_idx.append(temp_idx)
all_val.append([np.sum(img.intens[self.mask == -1])])
all_idx.append([-1])
return all_idx, all_val
def debug(self, img: FFT):
for count, ev_points in enumerate(self.eval_points, start=0):
for num in ev_points:
img.intens[self.mask == num] += count
return img
def get_mask(self):
return self.mask
@timeit
def generate_mask(self, img: FFT, merge=False):
hash = str(img.intens.shape)
self.mask = self.gen_mask_helper(img, hash=hash)
if merge:
self.mask = self.merge_mask_helper(hash=hash)
self.eval_points = [[a] for a in np.arange(len(self.eval_points))]
@abstractmethod
def gen_mask_helper(self, img: FFT, hash=str):
pass
@abstractmethod
def merge_mask_helper(self, hash: str):
pass
def purge(self, img: FFT):
img.intens[self.mask != -1] = 0
class Rect_Evaluator(Evaluator):
# def __init__(self, points, eval_points):
# self.eval_points = eval_points
# self.points = points
# self.length = 4
def __init__(self, spots: list[tuple[np.ndarray, np.ndarray]], length: int = 6):
self.spots = spots
self.length = length
self.eval_points = []
start = 0
for sp in spots:
self.eval_points.append(np.arange(start, start+sp[0].size))
start += sp[0].size
def merge_mask_helper(self):
new_eval_points = np.arange(len(self.eval_points))
mask = self.mask
for nc, ev_points in zip(new_eval_points, self.eval_points):
maske_low = np.min(ev_points) >= self.mask
maske_high = np.max(ev_points) <= self.mask
mask[np.logical_and(maske_high, maske_low)] = nc
return mask
def gen_mask_helper(self, img: FFT):
mask = np.full_like(img.intens, -1)
count = 0
for spot_group in self.spots:
logger.debug(f"Spot: {spot_group}")
(x, y) = spot_group
x, y = img.val2pos(x, y)
for x_p, y_p in zip(x, y):
xl, yl, xu, yu = clean_bounds_offset(
x_p, y_p, self.length, img.intens.shape)
mask[xl:xu, yl:yu] = count
count += 1
return mask
#
# def main():
# np.random.seed(10)
# points = (np.random.rand(100, 2)-0.5) * 2
# voro = Voronoi_Evaluator(points, [[1],[2]])
# rect = Rect_Evaluator(points, [[1], [2]])
# Z = np.ones((1000, 1000))
# img = Image_Wrapper(Z, -5, .01, -5, .01)
# voro.extract(img)
# rect.extract(img)
#
# plt.scatter(points[[1], 0], points[[1], 1])
# plt.scatter(points[[2], 0], points[[2], 1])
# plt.imshow(img.img, extent=img.ext(), origin="lower")
# #plt.imshow(img.img, origin="lower")
# plt.show()
#
#
# if __name__ == "__main__":
# main()
# class Voronoi_Evaluator(Evaluator):
# def __init__(self, list_points):
# points = np.concatenate(list_points, axis=0)
# self.eval_points = []
# start = 0
# for l in list_points:
# stop = l.shape[0]
# self.eval_points.append(np.arange(start, start + stop))
# start += stop
# self.vor = Voronoi(points)
#
# @persist_to_file("cache_merge_voro")
# def merge_mask_helper(self):
# new_eval_points = np.arange(len(self.eval_points))
# mask = self.mask
# for nc, ev_points in zip(new_eval_points, self.eval_points):
# for num in ev_points:
# mask[self.mask == num] = nc
# return mask
#
# @persist_to_file("cache_voro")
# def gen_mask_helper(self, img: Image_Wrapper):
# mask = np.full_like(img.img, -1)
#
# counter = -1
# region_mask = self.vor.point_region
# for i in np.array(self.vor.regions, dtype=list)[region_mask]:
# counter += 1
# if -1 in i:
# continue
# if len(i) == 0:
# continue
# pts = self.vor.vertices[i]
# pts = np.stack(img.val2pos(
# pts[:, 0], pts[:, 1])).astype(np.int32).T
# if np.any(pts < 0):
# continue
# mask_2 = np.zeros_like(img.img)
# cv2.fillConvexPoly(mask_2, pts, 1)
# mask_2 = mask_2 > 0 # To convert to Boolean
# mask[mask_2] = counter
# return mask

150
clean_python/lattices.py Normal file
View File

@ -0,0 +1,150 @@
import numpy as np
from cache import timeit
from abc import ABC, abstractmethod
def deg_2_rad(winkel):
return winkel / 180.0 * np.pi
class Lattice(ABC):
@abstractmethod
def get_phase(self, index: int) -> tuple[np.ndarray, np.ndarray]:
pass
@abstractmethod
def parse_mask(self, mask: np.ndarray) -> np.ndarray:
pass
@abstractmethod
def get_phases(self) -> list[tuple[np.ndarray, np.ndarray]]:
pass
@abstractmethod
def get_spots(self) -> list[tuple[np.ndarray, np.ndarray]]:
pass
class SCC_Lattice(Lattice):
X: np.ndarray
Y: np.ndarray
def __init__(self, x_len: int, y_len: int):
x = np.arange(x_len) * 5
y = np.arange(x_len) * 4.5
self.X, self.Y = np.meshgrid(x, y)
def get_phase(self, index: int) -> tuple[np.ndarray, np.ndarray]:
return (self.X, self.Y)
def get_phases(self) -> list[tuple[np.ndarray, np.ndarray]]:
return [self.get_phase(0)]
def get_spots(self) -> list[tuple[np.ndarray, np.ndarray]]:
x = np.arange(-3, 4) * 0.2
y = np.arange(-3, 4) * 1./4.5
X, Y = np.meshgrid(x, y)
return [(X.flatten(), Y.flatten())]
def parse_mask(self, mask: np.ndarray) -> np.ndarray:
return mask
class VO2_Lattice(Lattice):
base_a_m = 5.75
base_b_m = 4.5
base_c_m = 5.38
base_c_r = 2.856
base_b_r = 4.554
base_a_r = base_b_r
alpha_m = 122.64 # degree
def __init__(self, x_len: int, y_len: int):
self.X, self.Y = self._generate_vec(x_len * 2, y_len * 2)
def parse_mask(self, mask: np.ndarray) -> np.ndarray:
maske = np.empty((mask.shape[0]*2, mask.shape[1]*2))
maske[0::2, 0::2] = mask
maske[1::2, 0::2] = mask
maske[0::2, 1::2] = mask
maske[1::2, 1::2] = mask
return maske
def get_phase(self, index: int) -> tuple[np.ndarray, np.ndarray]:
if index == 0:
return self._get_rutile()
else:
return self._get_mono()
def get_phases(self) -> list[tuple[np.ndarray, np.ndarray]]:
return [self.get_phase(0), self.get_phase(1)]
def get_spots(self) -> list[tuple[np.ndarray, np.ndarray]]:
p1 = self._reci_rutile()
p2 = self._reci_mono()
p3 = self._reci_mono_2()
return [p1, p2, p3]
pass
def _mono_2_rutile(self, c_m, a_m):
a_r = np.cos(deg_2_rad(self.alpha_m - 90)) * c_m * self.base_c_m
c_r = (a_m) * self.base_a_m + \
np.sin(deg_2_rad(self.alpha_m - 90)) * c_m * self.base_c_m
return a_r, c_r
def _get_rutile(self):
x = self.X * self.base_c_r + \
np.mod(self.Y, 4) * 0.5 * self.base_c_r
y = self.Y * 0.5 * self.base_a_r
return (x, y)
def _get_mono(self):
offset_a_m = 0.25 - 0.23947
offset_c_m = 0.02646
offset_a_r, offset_c_r = self._mono_2_rutile(offset_c_m, offset_a_m)
res = 0.05
offset_a_r = res * int(offset_a_r/res)
offset_c_r = res * int(offset_c_r/res)
x = offset_a_r + self.X * \
self.base_c_r + np.mod(self.Y, 4) * 0.5 * self.base_c_r
x[np.mod(self.X, 2) == 0] -= 2 * offset_a_r
y = offset_c_r + 0.5 * self.Y * self.base_a_r
y[np.mod(self.X, 2) == 0] -= 2 * offset_c_r
return x, y
def _generate_vec(self, x_len: int, y_len: int):
x = np.arange(x_len)
y = np.arange(y_len)
X, Y = np.meshgrid(x, y)
X[np.mod(Y, 4) == 3] = X[np.mod(Y, 4) == 3] - 1
X[np.mod(Y, 4) == 2] = X[np.mod(Y, 4) == 2] - 1
assert np.mod(x.size, 2) == 0
assert np.mod(y.size, 2) == 0
return X, Y
def _reci_rutile(self):
len_x = np.max(self.X)
len_y = np.max(self.Y)
x = np.arange(-len_x, len_x+1)
y = np.arange(-len_y, len_y+1)
X, Y = np.meshgrid(x, y)
x = X * 1./self.base_c_r
y = Y * 2 * 1./self.base_a_r + np.mod(X, 2) * 1./self.base_a_r
return x.flatten(), y.flatten()
def _reci_mono(self):
x, y = self._reci_rutile()
return x + 0.5 * 1. / self.base_c_r, y + 0.5 * 1./self.base_a_r
def _reci_mono_2(self):
x, y = self._reci_rutile()
return x - 0.5 * 1. / self.base_c_r, y + 0.5 * 1./self.base_a_r

224
clean_python/main.py Normal file
View File

@ -0,0 +1,224 @@
from lattices import SCC_Lattice, VO2_Lattice
from spin_image import SpinImage
import numpy as np
import matplotlib.pyplot as plt
import tqdm
from extractors import Rect_Evaluator
from cache import timeit
from scipy import signal
from plotter import Plotter
import scipy.fftpack as sfft
import logging
logger = logging.getLogger('fft')
# logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)
def test_mixed():
fig, axs = plt.subplots(3, 3)
LEN = 40
lat = VO2_Lattice(LEN, LEN)
plot = Plotter(lat)
si = SpinImage(lat.get_phases())
mask_misk = np.ones((LEN, LEN))
ind = np.arange(mask_misk.size)
np.random.shuffle(ind)
mask_misk[np.unravel_index(ind[:800], (LEN, LEN))] = 0
si.apply_mask(lat.parse_mask(np.zeros((LEN, LEN))))
print("Clean Rutile: ", si.get_intens(
lat.parse_mask(np.zeros((LEN, LEN)))))
si.gaussian(20)
print("Rutile: ", si.get_intens(lat.parse_mask(np.zeros((LEN, LEN)))))
intens_mono = si.fft()
intens_mono.clean()
plot.plot_spins(si=si, ax_lin=axs[0, 0])
si.apply_mask(lat.parse_mask(np.ones((LEN, LEN))))
print("Clean Mono: ", si.get_intens(lat.parse_mask(np.ones((LEN, LEN)))))
si.gaussian(20)
print("Mono: ", si.get_intens(lat.parse_mask(np.ones((LEN, LEN)))))
intens_rutile = si.fft()
intens_rutile.clean()
plot.plot_spins(si=si, ax_lin=axs[0, 2])
si.apply_mask(lat.parse_mask(mask_misk))
print("Clean Mixed: ", si.get_intens(lat.parse_mask(mask_misk)))
si.gaussian(20)
print("Mixed: ", si.get_intens(lat.parse_mask(mask_misk)))
intens_mixed = si.fft()
intens_mixed.clean()
plot.plot_spins(si=si, ax_lin=axs[0, 1])
plot.plot_fft(intens_mono,
ax_log=axs[1, 0], ax_lin=axs[2, 0])
plot.plot_fft(intens_rutile,
ax_log=axs[1, 2], ax_lin=axs[2, 2])
plot.plot_fft(intens_mixed,
ax_log=axs[1, 1], ax_lin=axs[2, 1])
plt.figure()
plot.plot_fft(intens_mixed,
ax_log=plt.gca())
# Plotting cuts
def test_pdf():
LEN = 40
lat = VO2_Lattice(LEN, LEN)
plot = Plotter(lat)
si = SpinImage(lat.get_phases())
integrate = 10
out_intens = None
already_inited = False
for i in range(integrate):
mask_misk = np.ones((LEN, LEN))
ind = np.arange(mask_misk.size)
np.random.shuffle(ind)
mask_misk[np.unravel_index(ind[:800], (LEN, LEN))] = 0
si.apply_mask(lat.parse_mask(mask_misk))
si.gaussian(20)
intens = si.fft()
intens.clean()
if not already_inited:
print("Init")
rect = Rect_Evaluator(lat.get_spots())
rect.generate_mask(intens, merge=True)
out_intens = intens
already_inited = True
else:
out_intens.intens += intens.intens
out_intens = intens
rect.purge(intens)
plt.figure()
plot.plot_fft(intens, ax_log=plt.gca())
pdf = sfft.fft2(intens.intens)
pdf = sfft.fftshift(pdf)
plt.figure()
plt.imshow(np.abs(pdf))
def random(seed):
np.random.seed(seed)
LEN = 40
lat = VO2_Lattice(LEN, LEN)
maske = np.zeros((LEN, LEN))
ind = np.arange(LEN * LEN)
np.random.shuffle(ind)
rect = Rect_Evaluator(lat.get_spots())
out_rect = [[] for x in range(4)]
percentage = []
weighted_percentage = []
counter = 0
si = SpinImage(lat.get_phases())
already_inited = False
for i in tqdm.tqdm(ind):
maske[np.unravel_index(i, (LEN, LEN))] = True
counter += 1
if np.mod(counter, 100) != 0:
continue
si.apply_mask(lat.parse_mask(maske))
si.gaussian(20)
intens = si.fft()
if not already_inited:
rect.generate_mask(intens, merge=True)
already_inited = True
ir, vr = rect.extract(intens)
for lis, val in zip(out_rect, vr):
lis.append(val)
percentage.append(np.sum(maske))
[p1, p2] = si.get_intens(lat.parse_mask(maske))
weighted_percentage.append(p1/(p1+p2))
percentage = np.array(percentage)
weighted_percentage = np.array(weighted_percentage)
percentage /= np.max(percentage)
np.savez(f"random_rect_{seed}.npz",
w_percentage=weighted_percentage, percentage=percentage, out_1=out_rect[0],
out_2=out_rect[1], out_3=out_rect[2], out_4=out_rect[3])
def sample_index(p):
i = np.random.choice(np.arange(p.size), p=p.ravel())
return np.unravel_index(i, p.shape)
def ising(seed, temp=0.5):
np.random.seed(seed)
LEN = 40
lat = VO2_Lattice(LEN, LEN)
maske = np.zeros((LEN, LEN))
rect = Rect_Evaluator(lat.get_spots())
out_rect = [[] for x in range(4)]
percentage = []
weighted_percentage = []
counter = 0
si = SpinImage(lat.get_phases())
already_inited = False
for i in tqdm.tqdm(range(LEN*LEN)):
probability = np.roll(maske, 1, axis=0).astype(float)
probability += np.roll(maske, -1, axis=0).astype(float)
probability += np.roll(maske, 1, axis=1).astype(float)
probability += np.roll(maske, -1, axis=1).astype(float)
probability = np.exp(probability/temp)
probability[maske > 0] = 0
probability /= np.sum(probability)
maske[sample_index(probability)] = True
counter += 1
if np.mod(counter, 100) != 0:
continue
si.apply_mask(lat.parse_mask(maske))
si.gaussian(20)
intens = si.fft()
if not already_inited:
rect.generate_mask(intens, merge=True)
already_inited = True
ir, vr = rect.extract(intens)
for lis, val in zip(out_rect, vr):
lis.append(val)
percentage.append(np.sum(maske))
[p1, p2] = si.get_intens(lat.parse_mask(maske))
weighted_percentage.append(p1/(p1+p2))
percentage = np.array(percentage)
weighted_percentage = np.array(weighted_percentage)
percentage /= np.max(percentage)
np.savez(f"ising_{temp}_rect_{seed}.npz",
w_percentage=weighted_percentage, percentage=percentage, out_1=out_rect[0],
out_2=out_rect[1], out_3=out_rect[2], out_4=out_rect[3])
if __name__ == "__main__":
np.random.seed(1234)
# test_me()
# test_square()
# test_mixed()
# plt.show()
# random(1234)
# ising(1234)
test_pdf()
plt.show()
exit()
for i in np.random.randint(0, 10000, 5):
random(i)
ising(i, 0.5)
ising(i, 1.0)
ising(i, 1.5)
# plt.show()

2
clean_python/mypy.conf Normal file
View File

@ -0,0 +1,2 @@
[mypy]
plugins = numpy.typing.mypy_plugin

56
clean_python/plotter.py Normal file
View File

@ -0,0 +1,56 @@
import matplotlib.pyplot as plt
from spin_image import SpinImage, FFT
import matplotlib
class Plotter:
def __init__(self, lat):
self.lattice = lat
self.length_2 = 0.05
def plot_spins(self, si: SpinImage, ax_log=None, ax_lin=None):
if ax_log:
t = ax_log.imshow(
si.img,
extent=(0, si.length_x, 0, si.length_y),
norm=matplotlib.colors.LogNorm(vmin=1e-12),
cmap="viridis",
origin="lower"
)
plt.colorbar(t, ax=ax_log, extend="min")
if ax_lin:
t = ax_lin.imshow(
si.img,
extent=(0, si.length_x, 0, si.length_y),
cmap="viridis",
origin="lower"
)
plt.colorbar(t, ax=ax_lin, extend="min")
def plot_fft(self, fft: FFT, ax_log=None, ax_lin=None, evaluator=None):
if ax_log:
if evaluator:
evaluator.debug(fft)
t = ax_log.imshow(
fft.intens,
extent=fft.extents(),
norm=matplotlib.colors.LogNorm(),
cmap="viridis",
origin="lower"
)
plt.colorbar(t, ax=ax_log, extend="min")
ax_log.set_xlim(-2, 2)
ax_log.set_ylim(-2, 2)
ax_log.set_xlim(-8, 8)
ax_log.set_ylim(-8, 8)
if ax_lin:
t = ax_lin.imshow(
fft.intens,
extent=fft.extents(),
cmap="viridis",
origin="lower"
)
plt.colorbar(t, ax=ax_lin, extend="min")
ax_lin.set_xlim(-2, 2)
ax_lin.set_ylim(-2, 2)

155
clean_python/spin_image.py Normal file
View File

@ -0,0 +1,155 @@
from dataclasses import dataclass
from tools import clean_bounds, clean_bounds_offset
from cache import timeit
import scipy.signal
import scipy.fftpack as sfft
import scipy
import numpy as np
import logging
logger = logging.getLogger("fft")
@dataclass
class FFT:
freqx: np.ndarray
freqy: np.ndarray
intens: np.ndarray
def extents(self):
return (np.min(self.freqx), np.max(self.freqx), np.min(self.freqy), np.max(self.freqy))
def clean(self):
total = np.sum(self.intens)
low = np.sum(self.intens[self.intens < 1e-12])
print(f"{low}/{total}")
self.intens[self.intens < 1e-12] = 1e-12
def val2pos(self, x, y):
xind = np.searchsorted(self.freqx, x).astype(int)
yind = np.searchsorted(self.freqy, y).astype(int)
return xind, yind
class SpinImage:
resolution = 0.05
offset = 40
length_x: float
length_y: float
sigma: float
index: list[tuple[np.ndarray, np.ndarray]]
buffer: list[np.ndarray]
intens_map: np.ndarray
def make_list(self, x_pos, y_pos, x_inds, y_inds):
x_ind = np.arange(0, self.length_x, self.resolution)
y_ind = np.arange(0, self.length_y, self.resolution)
X, Y = np.meshgrid(x_ind, y_ind, indexing="ij")
out_list = []
for x, y, x_ind, y_ind in zip(
x_pos.flatten(),
y_pos.flatten(),
x_inds.flatten(),
y_inds.flatten(),
):
xl, yl, xu, yu = clean_bounds_offset(
x_ind, y_ind, self.offset, X.shape)
out_list.append(np.exp(-0.5 * ((X[xl:xu, yl:yu] - x) ** 2 +
(Y[xl:xu, yl:yu] - y) ** 2) / self.sigma**2))
# out_list.append(np.ones_like(X[xl:xu, yl:yu]))
out_list = np.array(out_list, dtype=object)
return out_list
def __init__(self, phases: list[tuple[np.ndarray, np.ndarray]], sigma=.1):
self.sigma = sigma
zero_shift_x = 10000000
zero_shift_y = 10000000
max_len_x = 0
max_len_y = 0
for x_pos, y_pos in phases:
assert x_pos.shape == y_pos.shape
zero_shift_x = np.minimum(np.min(x_pos), zero_shift_x)
zero_shift_y = np.minimum(np.min(y_pos), zero_shift_y)
max_len_x = np.maximum(np.max(x_pos), max_len_x)
max_len_y = np.maximum(np.max(y_pos), max_len_y)
self.length_x = max_len_x + (2*self.offset + 10) * self.resolution
self.length_y = max_len_y + (2*self.offset + 10) * self.resolution
self.true_pos = []
self.index = []
self.buffer = []
offset_shift = (self.offset+3)*self.resolution
for x_pos, y_pos in phases:
x_pos = x_pos - zero_shift_x + offset_shift
y_pos = y_pos - zero_shift_y + offset_shift
x_index, y_index = self._to_index(x_pos, y_pos)
self.index.append((x_index, y_index))
self.true_pos.append((x_pos, y_pos))
buffer = self.make_list(x_pos, y_pos, x_index, y_index)
self.buffer.append(buffer)
def _to_index(self, pos_x, pos_y):
x_ind = np.arange(0, self.length_x, self.resolution)
y_ind = np.arange(0, self.length_y, self.resolution)
xind = np.searchsorted(x_ind, pos_x).astype(int)
yind = np.searchsorted(y_ind, pos_y).astype(int)
return xind, yind
def _apply_mask(self, idx, buffer, mask):
logger.info("_apply_mask")
(x_indices, y_indices) = idx
assert x_indices.shape == y_indices.shape
assert x_indices.flatten().shape == mask.shape,\
f"Invalid mask: {mask.shape} != {x_indices.flatten().shape}"
for x_idx, y_idx, dat in zip(
x_indices.flatten()[mask],
y_indices.flatten()[mask],
buffer[mask]
):
logger.debug("stamp")
xl, yl, xu, yu = clean_bounds_offset(
x_idx, y_idx, self.offset, self.img.shape)
# if self.img[xl: xu, yl: yu].shape == dat.astype(np.float64).shape:
self.img[xl: xu, yl: yu] += dat.astype(np.float64)
def apply_mask(self, maske):
self.img = np.zeros(
(int(self.length_x/self.resolution), int(self.length_y/self.resolution)))
checker = False
self.intens_map = []
for counter, data in enumerate(zip(self.index, self.buffer)):
(idx, buf) = data
self._apply_mask(idx, buf, (maske == counter).flatten())
if counter in maske:
checker = True
self.intens_map.append(np.ones_like(maske))
assert checker, "No stamp set"
def fft(self):
Z_fft = sfft.fft2(self.img)
Z_shift = sfft.fftshift(Z_fft)
fft_freqx = sfft.fftfreq(self.img.shape[0], self.resolution)
fft_freqy = sfft.fftfreq(self.img.shape[1], self.resolution)
fft_freqx_clean = sfft.fftshift(fft_freqx)
fft_freqy_clean = sfft.fftshift(fft_freqy)
return FFT(freqx=fft_freqx_clean, freqy=fft_freqy_clean, intens=np.abs(Z_shift) ** 2)
def gaussian(self, sigma):
x = np.linspace(0, self.length_x, self.img.shape[0])
y = np.linspace(0, self.length_y, self.img.shape[1])
X, Y = np.meshgrid(x, y, indexing="ij")
mu_y = (self.length_x / 2.)
mu_x = (self.length_y / 2.)
z = 1 / (2 * np.pi * sigma * sigma) * \
np.exp(-((X - mu_x)**2 / (2 * sigma**2) + (Y-mu_y)**2 / (2 * sigma**2)))
self.img = np.multiply(self.img, z)
for idx, map in zip(self.true_pos, self.intens_map):
(X, Y) = idx
z = 1 / (2 * np.pi * sigma * sigma) * \
np.exp(-((X - mu_x)**2 / (2 * sigma**2) + (Y-mu_y)**2 / (2 * sigma**2)))
map *= z
def get_intens(self, mask):
intens = []
for counter, data in enumerate(self.intens_map):
intensity = np.sum(data[mask == counter])
intens.append(intensity)
return intens

42
clean_python/test.py Normal file
View File

@ -0,0 +1,42 @@
from extractors import Rect_Evaluator
from plotter import Plotter
import numpy as np
import matplotlib.pyplot as plt
from spin_image import SpinImage
from lattices import SCC_Lattice, VO2_Lattice
import logging
logger = logging.getLogger('fft')
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)
def test_scc():
LEN = 40
#scc = SCC_Lattice(LEN, LEN)
scc = VO2_Lattice(LEN, LEN)
phases = scc.get_phases()
logger.info(f"Phases: {phases}")
si = SpinImage(phases)
re = Rect_Evaluator(scc.get_spots())
si.apply_mask(scc.parse_mask(np.ones((LEN, LEN))))
si.gaussian(20)
fft = si.fft()
fft.clean()
re.generate_mask(fft, merge=True)
fig, axs = plt.subplots(3, 2)
plot = Plotter(scc)
plot.plot_spins(si, ax_lin=axs[0, 0])
#plot.plot_fft(fft, ax_lin=axs[1, 0], ax_log=axs[2, 0])
plot.plot_fft(fft, ax_lin=axs[1, 0], ax_log=axs[2, 0], evaluator=re)
plt.show()
test_scc()
plt.show()

45
clean_python/tools.py Normal file
View File

@ -0,0 +1,45 @@
import numpy as np
def clean_bounds(xl, yl, xu, yu, shape):
if xl < 0:
xl = 0
if yl < 0:
yl = 0
if xu > shape[0]:
xu = shape[0]
if yu > shape[1]:
yu = shape[1]
return xl, yl, xu, yu
def clean_bounds_offset(x: int, y: int, offset: int, shape: tuple[int, int]):
xl = x - offset
xu = x + offset + 1
yl = y - offset
yu = y + offset + 1
# if isinstance(x, np.ndarray):
# xl[xl < 0] = 0
# xu[xu > shape[0]] = shape[0]
# xl = xl.astype(np.int64)
# xu = xu.astype(np.int64)
# else:
if xl < 0:
xl = 0
if xu > shape[0]:
xu = shape[0]
# if isinstance(y, np.ndarray):
# yl[yl < 0] = 0
# yu[yu > shape[1]] = shape[1]
# yl = yl.astype(np.int64)
# yu = yu.astype(np.int64)
# else:
if yl < 0:
yl = 0
if yu > shape[1]:
yu = shape[1]
return xl, yl, xu, yu