Skip to content

Commit

Permalink
updated Fit1D class
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Nov 18, 2024
1 parent 226b933 commit 0aa4bd4
Show file tree
Hide file tree
Showing 4 changed files with 359 additions and 137 deletions.
226 changes: 94 additions & 132 deletions src/tavi/data/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,19 @@
from lmfit import Parameters, models

from tavi.data.scan_data import ScanData1D
from tavi.plotter import Plot1D


class FitData1D(object):

def __init__(
self,
x: np.ndarray,
y: np.ndarray,
) -> None:

self.x = x
self.y = y
self.fmt: dict = {}


class Fit1D(object):
Expand Down Expand Up @@ -41,7 +53,6 @@ def __init__(
):
"""initialize a fit model, mask based on fit_range if given"""

self.NUM_PTS: int = 100
self.x: np.ndarray = data.x
self.y: np.ndarray = data.y
self.err: Optional[np.ndarray] = data.err
Expand All @@ -51,7 +62,8 @@ def __init__(
self.pars = Parameters()
self._num_backgrounds = 0
self._num_signals = 0
self.fit_result = None
self.result = None
self.fit_data: Optional[FitData1D] = None

self.PLOT_SEPARATELY = False

Expand All @@ -67,141 +79,91 @@ def set_range(self, fit_range: tuple[float, float]):
if self.err is not None:
self.err = self.err[mask]

@property
def x_plot(self):
return np.linspace(self.x.min(), self.x.max(), num=self.NUM_PTS)

def add_background(
self,
model: Literal["Constant", "Linear", "Quadratic", "Polynomial", "Exponential", "PowerLaw"] = "Constant",
values=None,
vary=None,
mins=None,
maxs=None,
exprs=None,
):
"""Set the model for background
Args:
model (str): Constant, Linear, Quadratic, Polynomial,
Exponential, PowerLaw
p0 (tuple | None): inital parameters
min (tuple | None): minimum
max (tuple | None): maximum
fixed (tuple | None): tuple of flags
expr (tuple| None ): constraint expressions
"""
self._num_backgrounds += 1

# add prefix if more than one background
if self._num_backgrounds > 1:
prefix = f"b{self._num_backgrounds}_"
else:
prefix = ""

model = Fit1D.models[model](prefix=prefix, nan_policy="propagate")
param_names = model.param_names
# guess initials
pars = model.guess(self.y, x=self.x)

# overwrite with user input
if values is not None:
for idx, v in enumerate(values):
if v is not None:
pars[param_names[idx]].set(value=v)

if vary is not None:
for idx, v in enumerate(vary):
if v is not None:
pars[param_names[idx]].set(vary=v)

if mins is not None:
for idx, v in enumerate(mins):
if v is not None:
pars[param_names[idx]].set(min=v)
if maxs is not None:
for idx, v in enumerate(maxs):
if v is not None:
pars[param_names[idx]].set(max=v)

if exprs is not None:
for idx, v in enumerate(exprs):
if v is not None:
pars[param_names[idx]].set(expr=v)

for param_name in param_names:
self.pars.add(pars[param_name])

self.background_models.append(model)
@staticmethod
def _add_model(model, prefix):
model = Fit1D.models[model]
return model(prefix=prefix, nan_policy="propagate")

def add_signal(
self,
model="Gaussian",
values=None,
vary=None,
mins=None,
maxs=None,
exprs=None,
model_name: Literal[
"Gaussian", "Lorentzian", "Voigt", "PseudoVoigt", "DampedOscillator", "DampedHarmonicOscillator"
],
):
"""Set the model for signal
Args:
model (str): Constant, Linear, Quadratic, Polynomial,
Exponential, PowerLaw
p0 (tuple | None): inital parameters
min (tuple | None): minimum
max (tuple | None): maximum
expr (str| None ): constraint expression
"""
self._num_signals += 1
prefix = f"s{self._num_signals}_"
model = Fit1D.models[model](prefix=prefix, nan_policy="propagate")
param_names = model.param_names
# guess initials
pars = model.guess(self.y, x=self.x)

# overwrite with user input
if values is not None:
for idx, v in enumerate(values):
if v is not None:
pars[param_names[idx]].set(value=v)

if vary is not None:
for idx, v in enumerate(vary):
if v is not None:
pars[param_names[idx]].set(vary=v)

if mins is not None:
for idx, v in enumerate(mins):
if v is not None:
pars[param_names[idx]].set(min=v)
if maxs is not None:
for idx, v in enumerate(maxs):
if v is not None:
pars[param_names[idx]].set(max=v)

if exprs is not None:
for idx, v in enumerate(exprs):
if v is not None:
pars[param_names[idx]].set(expr=v)

for param_name in param_names:
self.pars.add(pars[param_name])
self.signal_models.append(model)

def perform_fit(self) -> None:
model = np.sum(self.signal_models)

if self._num_backgrounds > 0:
model += np.sum(self.background_models)

if self.err is None:
out = model.fit(self.y, self.pars, x=self.x)
self.signal_models.append(Fit1D._add_model(model_name, prefix))

def add_background(
self, model_name: Literal["Constant", "Linear", "Quadratic", "Polynomial", "Exponential", "PowerLaw"]
):
self._num_backgrounds += 1
prefix = f"b{self._num_backgrounds}_"
self.background_models.append(Fit1D._add_model(model_name, prefix))

@staticmethod
def _get_model_params(models):
params = []
for model in models:
params.append(model.param_names)
return params

def get_signal_params(self):
return Fit1D._get_model_params(self.signal_models)

def get_background_params(self):
return Fit1D._get_model_params(self.background_models)

def guess(self):
pars = Parameters()
for signal in self.signal_models:
pars += signal.guess(self.y, x=self.x)
for bkg in self.background_models:
pars += bkg.guess(self.y, x=self.x)
self.pars = pars
return pars

@property
def x_to_plot(self):
return

def eval(self, pars: Parameters, num_of_pts: Optional[int] = 100) -> FitData1D:
mod = self.signal_models[0]
if (sz := len(self.signal_models)) > 1:
for i in range(1, sz):
mod += self.signal_models[i]

for bkg in self.background_models:
mod += bkg

if num_of_pts is None:
x_to_plot = self.x
elif isinstance(num_of_pts, int):
x_to_plot = np.linspace(self.x.min(), self.x.max(), num=num_of_pts)
else:
raise ValueError(f"num_of_points={num_of_pts} needs to be an integer.")
y_to_plot = mod.eval(pars, x=x_to_plot)
return FitData1D(x_to_plot, y_to_plot)

def fit(self, pars: Parameters, num_of_pts: Optional[int] = 100) -> FitData1D:
mod = self.signal_models[0]
if (sz := len(self.signal_models)) > 1:
for i in range(1, sz):
mod += self.signal_models[i]

for bkg in self.background_models:
mod += bkg

result = mod.fit(self.y, pars, x=self.x, weights=self.err)
self.result = result

if num_of_pts is None:
x_to_plot = self.x
elif isinstance(num_of_pts, int):
x_to_plot = np.linspace(self.x.min(), self.x.max(), num=num_of_pts)
else:
out = model.fit(self.y, self.pars, x=self.x, weights=self.err)
raise ValueError(f"num_of_points={num_of_pts} needs to be an integer.")

self.result = out
self.y_plot = model.eval(out.params, x=self.x_plot)
y_to_plot = mod.eval(result.params, x=x_to_plot)

self.fit_report = out.fit_report(min_correl=0.25)
self.fit_plot = Plot1D(x=self.x_plot, y=self.y_plot)
return FitData1D(x_to_plot, y_to_plot)
Loading

0 comments on commit 0aa4bd4

Please sign in to comment.