feat: add_int_error option for fitfraction
jiangyi15 committed Nov 4, 2024
1 parent ea42a58 commit 60d5326
Showing 4 changed files with 24 additions and 11 deletions.
tf_pwa/config_loader/config_loader.py (4 changes: 2 additions & 2 deletions)
@@ -1016,10 +1016,10 @@ def set_params(self, params, neglect_params=None):
             neglect_params = self._neglect_when_set_params
         if len(neglect_params) != 0:
             for v in params:
-                if v in self._neglect_when_set_params:
+                if v in neglect_params:
                     warnings.warn(
                         "Neglect {} when setting params.".format(
-                            neglect_params
+                            [i for i in params if i in neglect_params]
                         )
                     )
                     del ret[v]
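This hunk fixes two things: the membership test now honors the neglect_params argument (previously it always checked the instance attribute), and the warning lists only the parameters actually being dropped rather than the whole neglect list. A stand-alone sketch of the corrected logic (the function name and dict-based params are hypothetical, not the real ConfigLoader.set_params):

    import warnings

    def set_params_filtered(params, neglect_params):
        # Keep every parameter that is not on the neglect list.
        ret = dict(params)
        if len(neglect_params) != 0:
            for v in params:
                if v in neglect_params:
                    # Report only the overlap, not the full neglect list.
                    warnings.warn(
                        "Neglect {} when setting params.".format(
                            [i for i in params if i in neglect_params]
                        )
                    )
                    del ret[v]
        return ret

    # warns "Neglect ['b'] when setting params." ('c' is not in params)
    set_params_filtered({"a": 1.0, "b": 2.0}, neglect_params=["b", "c"])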
tf_pwa/data_trans/helicity_angle.py (4 changes: 3 additions & 1 deletion)
@@ -74,7 +74,9 @@ def get_all_mass(self, replace_mass):
         for i in self.decay_chain:
             for j in [i.core] + list(i.outs):
                 if j not in ms:
-                    if str(j) in replace_mass:
+                    if j in replace_mass:
+                        ms[j] = replace_mass[j]
+                    elif str(j) in replace_mass:
                         ms[j] = tf.convert_to_tensor(
                             replace_mass[str(j)], tf.float64
                         )
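After this change, replace_mass can be keyed either by the particle object itself (the value is used as-is) or by its string name (the value is converted to a float64 tensor). A minimal stand-alone sketch of the new lookup order, with an assumed helper name:

    import tensorflow as tf

    def resolve_mass(j, replace_mass):
        # Object key takes precedence and is used unchanged.
        if j in replace_mass:
            return replace_mass[j]
        # Fall back to the string name, converting to a float64 tensor.
        if str(j) in replace_mass:
            return tf.convert_to_tensor(replace_mass[str(j)], tf.float64)
        return None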
tf_pwa/fitfractions.py (25 changes: 18 additions & 7 deletions)
@@ -6,23 +6,26 @@
 from tf_pwa.data import LazyCall, data_split


-def eval_integral(
+def _eval_integral(
     f, data, var, weight=None, args=(), no_grad=False, kwargs=None
 ):
     kwargs = {} if kwargs is None else kwargs
     weight = 1.0 if weight is None else weight
     if no_grad:
-        ret = tf.reduce_sum(f(data, *args, **kwargs) * weight)
+        fx = f(data, *args, **kwargs) * weight
+        ret = tf.reduce_sum(fx)
         ret_grad = np.zeros((len(var),))
     else:
         with tf.GradientTape() as tape:
-            ret = tf.reduce_sum(f(data, *args, **kwargs) * weight)
+            fx = f(data, *args, **kwargs) * weight
+            ret = tf.reduce_sum(fx)
         ret_grad = tape.gradient(ret, var, unconnected_gradients="zero")
         if len(ret_grad) == 0:
             ret_grad = np.array([])
         else:
             ret_grad = np.stack([i.numpy() for i in ret_grad])
-    return ret.numpy(), ret_grad
+    int_square = tf.reduce_sum(fx**2)
+    return ret.numpy(), ret_grad, int_square.numpy()


 def force_list(x):
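The renamed _eval_integral now also returns int_square, the sum of squared weighted amplitude values. For a Monte Carlo sum S = sum_i w_i * f_i, the quantity sum_i (w_i * f_i)**2 is the standard estimate of its statistical variance, up to a -(S**2)/N correction that this commit neglects (slightly conservative). A numpy-only sketch of that estimate:

    import numpy as np

    rng = np.random.default_rng(0)
    fx = rng.exponential(size=10_000)  # stand-in for weight * |A(x)|**2

    integral = fx.sum()           # the cached integral (ret)
    var_est = (fx ** 2).sum()     # the new cached int_square
    rel_err = np.sqrt(var_est) / integral
    print(integral, rel_err)      # relative error scales like 1/sqrt(N)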
@@ -39,6 +42,7 @@ def __init__(self, amp, res):
         self.res = res
         self.cached_int = {}
         self.cached_grad = {}
+        self.cached_square = {}
         self.cached_int_total = 0.0
         self.cached_grad_total = np.zeros((self.n_var,))
         self.error_matrix = np.diag(np.zeros((self.n_var,)))
@@ -55,6 +59,7 @@ def init_res_table(self):
                 else:
                     name = (str(self.res[i]), str(self.res[j]))
                 self.cached_int[name] = 0.0
+                self.cached_square[name] = 0.0
                 self.cached_grad[name] = np.zeros_like((self.n_var,))

     def integral(self, mcdata, *args, batch=None, no_grad=False, **kwargs):
@@ -71,7 +76,7 @@ def append_int(self, mcdata, *args, weight=None, no_grad=False, **kwargs):
             mcdata = mcdata.eval()
         if weight is None:
             weight = mcdata.get("weight", 1.0)
-        int_mc, g_int_mc = eval_integral(
+        int_mc, g_int_mc, _ = _eval_integral(
             self.amp,
             mcdata,
             var=self.var,
@@ -92,7 +97,7 @@ def append_int(self, mcdata, *args, weight=None, no_grad=False, **kwargs):
                 else:
                     name = (str(self.res[i]), str(self.res[j]))
                 amp_tmp.set_used_res(fl(self.res[i]) + fl(self.res[j]))
-                int_tmp, g_int_tmp = eval_integral(
+                int_tmp, g_int_tmp, int_square = _eval_integral(
                     amp_tmp,
                     mcdata,
                     var=self.var,
@@ -102,6 +107,9 @@ def append_int(self, mcdata, *args, weight=None, no_grad=False, **kwargs):
                 )
                 self.cached_int[name] = self.cached_int[name] + int_tmp
                 self.cached_grad[name] = self.cached_grad[name] + g_int_tmp
+                self.cached_square[name] = (
+                    self.cached_square[name] + int_square
+                )

         self.amp.set_used_res(cahced_res)
@@ -142,7 +150,7 @@ def get_frac_grad(self, sum_diag=True):
         )
         return fit_frac, g_fit_frac

-    def get_frac(self, error_matrix=None, sum_diag=True):
+    def get_frac(self, error_matrix=None, sum_diag=True, add_int_error=False):
         if error_matrix is None:
             error_matrix = self.error_matrix
         fit_frac, g_fit_frac = self.get_frac_grad(sum_diag=sum_diag)
@@ -152,6 +160,9 @@ def get_frac(self, error_matrix=None, sum_diag=True, add_int_error=False):
         for k, v in g_fit_frac.items():
             e = np.sqrt(np.dot(np.dot(error_matrix, v), v))
             fit_frac_err[k] = e
+            if add_int_error and k in self.cached_square:
+                scale = 1 / self.cached_int_total**2
+                fit_frac_err[k] = np.sqrt(e**2 + self.cached_square[k] * scale)
         return fit_frac, fit_frac_err

     def __iter__(self):
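In get_frac, add_int_error=True folds the MC integration error into the quoted uncertainty: the cached sum of squares is scaled by 1 / cached_int_total**2 to turn it into a squared error on the fraction, then combined in quadrature with the parameter-propagation error. A worked numpy sketch with made-up numbers:

    import numpy as np

    e_param = 0.012          # sqrt(v^T V v), propagated from the fit covariance
    cached_square = 40.0     # accumulated sum((weight * |A|**2)**2) for one term
    int_total = 2.0e3        # cached_int_total, the total integral

    scale = 1 / int_total**2
    e_total = np.sqrt(e_param**2 + cached_square * scale)
    print(e_total)           # sqrt(1.44e-4 + 1.0e-5) ~= 0.0124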
tf_pwa/tests/test_full.py (2 changes: 1 addition & 1 deletion)
@@ -341,7 +341,7 @@ def pw_f(x, **kwargs):
     fit_result.save_as("toy_data/final_params.json")
     fit_frac, frac_err = toy_config.cal_fitfractions()
     fit_frac, frac_err = toy_config.cal_fitfractions(method="new")
-    fit_frac, frac_err = toy_config.cal_fitfractions(
+    fit_frac_obj = toy_config.cal_fitfractions(
         method="new", res=["R_BC", ["R_BD", "R_CD"]]
     )
     save_frac_csv("toy_data/fit_frac.csv", fit_frac)
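The test now keeps the object returned by the res=... call, which points at how the new option is meant to be used. A hedged usage sketch (assuming cal_fitfractions(method="new", ...) returns the FitFractions-style object whose get_frac is extended above):

    fit_frac_obj = toy_config.cal_fitfractions(
        method="new", res=["R_BC", ["R_BD", "R_CD"]]
    )
    # add_int_error=True adds the MC integration uncertainty in quadrature
    fit_frac, frac_err = fit_frac_obj.get_frac(add_int_error=True)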
