Skip to content

Commit

Permalink
Merge pull request #155 from jiangyi15/fit_zero_params
Browse files Browse the repository at this point in the history
Misc: some updates
  • Loading branch information
jiangyi15 authored Nov 5, 2024
2 parents 4263161 + 1983fb7 commit 716d3a8
Show file tree
Hide file tree
Showing 12 changed files with 175 additions and 37 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ tf_pwa/version.py
*.npy
*.json
*.log
*.yml
*.png
*.pdf
*.pptx
Expand Down
7 changes: 5 additions & 2 deletions examples/ex2_particle_amplitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
"R1": {"J": 0, "P": 1, "mass": 1.0, "width": 0.07},
"R2": {"J": 1, "P": -1, "mass": 1.225, "width": 0.08},
}

a, b, c, d = [get_particle(i, J=0, P=-1) for i in "ABCD"]
m_A, m_B, m_C, m_D = 1.8, 0.18, 0.18, 0.18
a, b, c, d = [
get_particle(i, J=0, P=-1, mass=m)
for i, m in zip("ABCD", [m_A, m_B, m_C, m_D])
]
r1, r2, r3 = [get_particle(i, **resonances[i]) for i in resonances.keys()]


Expand Down
47 changes: 38 additions & 9 deletions examples/ex_params_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,52 @@
print(a2_x.numpy(), pt.get_error(a2_x).numpy())

# %%
# We can also calculate some more complex examples, such as the ratio in mass range (0.75, 0.85) over full phace space.
# Uncertainties of fit fractions
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# We can also calculate some more complex examples, such as the fit fractions of all in `C+D`.
# Even further, we can get the error of error in the meaning of error propagation.

from tf_pwa.data import data_mask

m_R2 = phsp.get_mass("(B, D)")
cut_cond = (m_R2 < 0.85) & (m_R2 > 0.75)

amp = config.get_amplitude()

with config.params_trans() as pt1:
with config.params_trans() as pt:
int_mc = tf.reduce_sum(amp(phsp))
cut_phsp = data_mask(phsp, cut_cond)
cut_int_mc = tf.reduce_sum(amp(cut_phsp))
ratio = cut_int_mc / int_mc
with amp.temp_used_res(["R1_a", "R1_b"]):
part_int_mc = tf.reduce_sum(amp(phsp))
ratio = part_int_mc / int_mc
error = pt.get_error(ratio)

print(ratio.numpy(), "+/-", error.numpy())
print(error.numpy(), "+/-", pt1.get_error(error).numpy())

# %%
# For large data size it would be some problem named OOM (out of memory).
# TFPWA provide `vm.batch_sum_var` to do sum of large samples

int_mc_v = config.vm.batch_sum_var(amp, phsp, batch=5000)

with amp.temp_used_res(["R1_a", "R1_b"]):
part_int_mc_v = config.vm.batch_sum_var(amp, phsp, batch=5000)

# %%
# It will store the pre-calculated gradients as

print(int_mc_v.grad, part_int_mc_v.grad)

# %%
# Then, we can use it as a function to do error propagation:

with config.params_trans() as pt:
ratio = part_int_mc_v() / int_mc_v()
error = pt.get_error(ratio)

print(ratio.numpy(), "+/-", error.numpy())

# %%
# Besides the error propagation, there would be some additional uncertainties.
# For example, the uncertainty from the integration sample size is often defined as the sum of square as

with amp.temp_used_res(["R1_a", "R1_b"]):
int_square = tf.reduce_sum((amp(phsp) / int_mc) ** 2)

print(ratio.numpy(), "+/-", error.numpy(), "+/-", tf.sqrt(int_square).numpy())
2 changes: 1 addition & 1 deletion fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def print_fit_result_roofit(config, fit_result):
value = fit_result.params
params_name = config.vm.trainable_vars
n_par = len(params_name)
name_size = max(len(i) for i in params_name)
name_size = max([5] + [len(i) for i in params_name])
# fcn = config.get_fcn()
# _, grad = fcn.nll_grad(fit_result.params)
# edm = np.dot(np.dot(config.inv_he, grad), grad)
Expand Down
5 changes: 5 additions & 0 deletions tf_pwa/amp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,11 @@ def add_var(self, names, is_complex=False, shape=(), **kwargs):
"""
if not hasattr(self, "_variables_map"):
self._variables_map = {}
if True:
default_config = getattr(self, "default_params", {}).get(names, {})
if isinstance(default_config, (float, int)):
default_config = {"value": default_config}
kwargs.update(default_config)
name = self.get_variable_name(names)
var = Variable(name, shape, is_complex, **kwargs)
self._variables_map[names] = var
Expand Down
15 changes: 12 additions & 3 deletions tf_pwa/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,24 @@ def fit_fractions(
:param mcdata: MCdata array.
:param inv_he: The inverse of Hessian matrix. If it's not given, the errors will not be calculated.
:return frac: Dictionary of fit fractions for each resonance.
:return err_frac: Dictionary of their errors. If ``inv_he`` is ``None``, it will be a dictionary of ``None``.
:return err_frac: Dictionary of their errors. If ``inv_he`` is ``None``, it will be a dictionary of ``{}``.
"""
if params is None:
params = {}
err_frac = {}
if method == "old":
with amp.temp_params(params):
frac, grad = cal_fitfractions(amp, mcdata, res=res, batch=batch)
if inv_he is not None:
with amp.temp_params(params):
frac, grad = cal_fitfractions(
amp, mcdata, res=res, batch=batch
)
for i in frac:
err_frac[i] = np.sqrt(np.dot(np.dot(inv_he, grad[i]), grad[i]))
else:
with amp.temp_params(params):
frac = cal_fitfractions_no_grad(
amp, mcdata, res=res, batch=batch
)
return frac, err_frac
else:
ret = FitFractions(amp, res)
Expand Down Expand Up @@ -289,6 +296,8 @@ def cal_hesse_error(
:return hesse_error: List of errors.
:return inv_he: The inverse Hessian matrix.
"""
if len(fcn.vm.trainable_vars) == 0:
return {}, None
t = time.time()
nll, g, h = fcn.nll_grad_hessian(
params
Expand Down
12 changes: 6 additions & 6 deletions tf_pwa/config_loader/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,6 @@ def add_constraints(self, amp):
constrains = {}
self.add_decay_constraints(amp, constrains.get("decay", {}))
self.add_particle_constraints(amp, constrains.get("particle", {}))
self.add_fix_var_constraints(amp, constrains.get("fix_var", {}))
self.add_free_var_constraints(amp, constrains.get("free_var", []))
self.add_var_range_constraints(amp, constrains.get("var_range", {}))
self.add_var_equal_constraints(amp, constrains.get("var_equal", []))
self.add_pre_trans_constraints(amp, constrains.get("pre_trans", None))
Expand All @@ -273,6 +271,8 @@ def add_constraints(self, amp):
self.add_gauss_constr_constraints(
amp, constrains.get("gauss_constr", {})
)
self.add_fix_var_constraints(amp, constrains.get("fix_var", {}))
self.add_free_var_constraints(amp, constrains.get("free_var", []))
for k, v in self.extra_constrains.items():
v(amp, constrains.get(k, {}))

Expand Down Expand Up @@ -832,8 +832,8 @@ def get_params_error(
params = {}
if correct_params is None:
correct_params = []
if method is None:
method = "correct"
if len(correct_params) > 0 and method is None:
method = "correct"
if hasattr(params, "params"):
params = getattr(params, "params")
if not using_cached:
Expand Down Expand Up @@ -1016,10 +1016,10 @@ def set_params(self, params, neglect_params=None):
neglect_params = self._neglect_when_set_params
if len(neglect_params) != 0:
for v in params:
if v in self._neglect_when_set_params:
if v in neglect_params:
warnings.warn(
"Neglect {} when setting params.".format(
neglect_params
[i for i in params if i in neglect_params]
)
)
del ret[v]
Expand Down
14 changes: 11 additions & 3 deletions tf_pwa/data_trans/helicity_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from tf_pwa.amp import get_relative_p
from tf_pwa.angle import LorentzVector as lv
from tf_pwa.data import data_index


class HelicityAngle1:
Expand Down Expand Up @@ -73,7 +74,9 @@ def get_all_mass(self, replace_mass):
for i in self.decay_chain:
for j in [i.core] + list(i.outs):
if j not in ms:
if str(j) in replace_mass:
if j in replace_mass:
ms[j] = replace_mass[j]
elif str(j) in replace_mass:
ms[j] = tf.convert_to_tensor(
replace_mass[str(j)], tf.float64
)
Expand All @@ -90,7 +93,9 @@ def generate_p_mass(self, name, m, random=False):
for i in self.decay_chain:
data[i] = {}
data[i]["|p|"] = get_relative_p(
ms[i.core], ms[i.outs[0]], ms[i.outs[1]]
data_index(ms, i.core),
data_index(ms, i.outs[0]),
data_index(ms, i.outs[1]),
)
if random:
costheta = np.random.random(m.shape) * 2 - 1
Expand All @@ -108,10 +113,13 @@ def generate_p_mass(self, name, m, random=False):
def build_data(self, ms, costheta, phi):
"""generate monmentum with M_name = m"""
data = {}
ms = self.get_all_mass(ms)
for j, i in enumerate(self.decay_chain):
data[i] = {}
data[i]["|p|"] = get_relative_p(
ms[i.core], ms[i.outs[0]], ms[i.outs[1]]
data_index(ms, i.core),
data_index(ms, i.outs[0]),
data_index(ms, i.outs[1]),
)
costheta_i = costheta[j]
phi_i = phi[j]
Expand Down
7 changes: 6 additions & 1 deletion tf_pwa/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,12 @@ def callback(x):
f_g = Cached_FG(f_g, grad_scale=grad_scale)
# print(f_g)
x0 = np.array(fcn.vm.get_all_val(True))
# print(x0, fcn.vm.get_all_dic())
if len(x0) == 0:
min_nll, _ = f_g(x0)
params = fcn.get_params()
return FitResult(
params, fcn, min_nll, ndf=0, success=True, hess_inv=None
)
# s = minimize(f_g, x0, method='trust-constr', jac=True, hess=BFGS(), options={'gtol': 1e-4, 'disp': True})
if method == "test":
try:
Expand Down
42 changes: 31 additions & 11 deletions tf_pwa/fitfractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,32 @@
from tf_pwa.data import LazyCall, data_split


def eval_integral(
def _eval_integral(
f, data, var, weight=None, args=(), no_grad=False, kwargs=None
):
kwargs = {} if kwargs is None else kwargs
weight = 1.0 if weight is None else weight
if no_grad:
ret = tf.reduce_sum(f(data, *args, **kwargs) * weight)
fx = f(data, *args, **kwargs) * weight
ret = tf.reduce_sum(fx)
ret_grad = np.zeros((len(var),))
else:
with tf.GradientTape() as tape:
ret = tf.reduce_sum(f(data, *args, **kwargs) * weight)
fx = f(data, *args, **kwargs) * weight
ret = tf.reduce_sum(fx)
ret_grad = tape.gradient(ret, var, unconnected_gradients="zero")
ret_grad = np.stack([i.numpy() for i in ret_grad])
return ret.numpy(), ret_grad
if len(ret_grad) == 0:
ret_grad = np.array([])
else:
ret_grad = np.stack([i.numpy() for i in ret_grad])
int_square = tf.reduce_sum(fx**2)
return ret.numpy(), ret_grad, int_square.numpy()


def force_list(x):
if isinstance(x, (list, tuple)):
return x
return [x]


class FitFractions:
Expand All @@ -30,6 +42,7 @@ def __init__(self, amp, res):
self.res = res
self.cached_int = {}
self.cached_grad = {}
self.cached_square = {}
self.cached_int_total = 0.0
self.cached_grad_total = np.zeros((self.n_var,))
self.error_matrix = np.diag(np.zeros((self.n_var,)))
Expand All @@ -46,6 +59,7 @@ def init_res_table(self):
else:
name = (str(self.res[i]), str(self.res[j]))
self.cached_int[name] = 0.0
self.cached_square[name] = 0.0
self.cached_grad[name] = np.zeros_like((self.n_var,))

def integral(self, mcdata, *args, batch=None, no_grad=False, **kwargs):
Expand All @@ -62,7 +76,7 @@ def append_int(self, mcdata, *args, weight=None, no_grad=False, **kwargs):
mcdata = mcdata.eval()
if weight is None:
weight = mcdata.get("weight", 1.0)
int_mc, g_int_mc = eval_integral(
int_mc, g_int_mc, _ = _eval_integral(
self.amp,
mcdata,
var=self.var,
Expand All @@ -74,15 +88,16 @@ def append_int(self, mcdata, *args, weight=None, no_grad=False, **kwargs):
self.cached_grad_total += g_int_mc
cahced_res = self.amp.used_res
amp_tmp = self.amp
fl = force_list
for i in range(len(self.res)):
for j in range(i, -1, -1):
if i == j:
name = str(self.res[i])
amp_tmp.set_used_res([self.res[i]])
amp_tmp.set_used_res(fl(self.res[i]))
else:
name = (str(self.res[i]), str(self.res[j]))
amp_tmp.set_used_res([self.res[i], self.res[j]])
int_tmp, g_int_tmp = eval_integral(
amp_tmp.set_used_res(fl(self.res[i]) + fl(self.res[j]))
int_tmp, g_int_tmp, int_square = _eval_integral(
amp_tmp,
mcdata,
var=self.var,
Expand All @@ -92,6 +107,9 @@ def append_int(self, mcdata, *args, weight=None, no_grad=False, **kwargs):
)
self.cached_int[name] = self.cached_int[name] + int_tmp
self.cached_grad[name] = self.cached_grad[name] + g_int_tmp
self.cached_square[name] = (
self.cached_square[name] + int_square
)

self.amp.set_used_res(cahced_res)

Expand Down Expand Up @@ -130,10 +148,9 @@ def get_frac_grad(self, sum_diag=True):
g_fit_frac["sum_diag"] = sum(
[g_fit_frac[str(i)] for i in self.res]
)
print(fit_frac)
return fit_frac, g_fit_frac

def get_frac(self, error_matrix=None, sum_diag=True):
def get_frac(self, error_matrix=None, sum_diag=True, add_int_error=False):
if error_matrix is None:
error_matrix = self.error_matrix
fit_frac, g_fit_frac = self.get_frac_grad(sum_diag=sum_diag)
Expand All @@ -143,6 +160,9 @@ def get_frac(self, error_matrix=None, sum_diag=True):
for k, v in g_fit_frac.items():
e = np.sqrt(np.dot(np.dot(error_matrix, v), v))
fit_frac_err[k] = e
if add_int_error and k in self.cached_square:
scale = 1 / self.cached_int_total**2
fit_frac_err[k] = np.sqrt(e**2 + self.cached_square[k] * scale)
return fit_frac, fit_frac_err

def __iter__(self):
Expand Down
Loading

0 comments on commit 716d3a8

Please sign in to comment.