Skip to content

Commit d1d89ca

Browse files
committed
add all steps to traj
1 parent 6022df3 commit d1d89ca

File tree

2 files changed

+41
-11
lines changed

2 files changed

+41
-11
lines changed
483 KB
Loading

examples/integration/tsadar/fitter.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import numpy as np
22
from numpy import ndarray
33
from scipy.optimize import minimize
4+
from matplotlib.animation import FuncAnimation
5+
import matplotlib.pyplot as plt
46

57
from tesseract_core import Tesseract
68

7-
from .animate import animate
8-
99
tesseract_url = "http://localhost:8000" # Change this to the correct address
1010

1111
tsadar = Tesseract(url=tesseract_url)
@@ -45,12 +45,6 @@ def to_dict(params: ndarray) -> dict:
4545
}
4646

4747

48-
def mse(pred: ndarray, true: ndarray) -> float:
49-
"""Mean Squared Error."""
50-
mse = np.mean(np.square(pred - true))
51-
return mse
52-
53-
5448
def jacobian(parameters: np.ndarray, true_electron_spectrum: np.ndarray) -> np.ndarray:
5549
"""Compute the gradient of the MSE loss function with respect to the parameters."""
5650
# Compute the gradient
@@ -97,9 +91,16 @@ def jacobian(parameters: np.ndarray, true_electron_spectrum: np.ndarray) -> np.n
9791
trajectory.append(electron_spectrum)
9892

9993

94+
def mse(pred: ndarray, true: ndarray) -> float:
95+
"""Mean Squared Error."""
96+
mse = np.mean(np.square(pred - true))
97+
trajectory.append(pred)
98+
print(f"loss {mse}")
99+
return mse
100+
100101
def callback(xk):
101102
electron_spectrum = tsadar.apply(to_dict(xk))["electron_spectrum"]
102-
trajectory.append(electron_spectrum)
103+
# trajectory.append(electron_spectrum)
103104
print(f"loss: {mse(electron_spectrum, true_electron_spectrum)}")
104105

105106

@@ -111,7 +112,36 @@ def callback(xk):
111112
x0=parameters,
112113
method="L-BFGS-B",
113114
options={"maxiter": 200, "maxls": 10},
114-
callback=callback,
115+
# callback=callback,
115116
)
116117

117-
animate(trajectory, true_electron_spectrum)
118+
119+
def animate(trajectory : list[np.ndarray], true_electron_spectrum: np.ndarray):
120+
121+
n = len(trajectory)
122+
123+
optim_steps = np.linspace(0, n, n + 1)
124+
125+
# repeat last trajectory point
126+
for _ in range(10):
127+
trajectory.append(trajectory[-1])
128+
optim_steps = np.append(optim_steps, n)
129+
fig, ax = plt.subplots()
130+
131+
132+
def update(i):
133+
ax.clear()
134+
ax.set_xlabel("Wavelength")
135+
ax.set_ylabel("Intensity")
136+
ax.set_title(f"Optimization step {int(optim_steps[i])}")
137+
138+
ax.plot(trajectory[i], label="Fit")
139+
ax.plot(true_electron_spectrum, label="True")
140+
ax.legend()
141+
ax.grid()
142+
143+
144+
ani = FuncAnimation(fig, update, frames=len(trajectory), repeat=False)
145+
ani.save("fit_trajectory.gif", writer="imagemagick", fps=3)
146+
147+
animate(trajectory, true_electron_spectrum)

0 commit comments

Comments
 (0)