From 60d53261e9f3c5a8cda0130a60ce884b5fe51b0f Mon Sep 17 00:00:00 2001
From: jiangyi15
Date: Tue, 5 Nov 2024 00:17:22 +0800
Subject: [PATCH] feat: add_int_error option for fitfraction

---
 tf_pwa/config_loader/config_loader.py |  4 ++--
 tf_pwa/data_trans/helicity_angle.py   |  4 +++-
 tf_pwa/fitfractions.py                | 25 ++++++++++++++++++-------
 tf_pwa/tests/test_full.py             |  2 +-
 4 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/tf_pwa/config_loader/config_loader.py b/tf_pwa/config_loader/config_loader.py
index d0f2bf2..020c38b 100644
--- a/tf_pwa/config_loader/config_loader.py
+++ b/tf_pwa/config_loader/config_loader.py
@@ -1016,10 +1016,10 @@ def set_params(self, params, neglect_params=None):
         neglect_params = self._neglect_when_set_params
     if len(neglect_params) != 0:
         for v in params:
-            if v in self._neglect_when_set_params:
+            if v in neglect_params:
                 warnings.warn(
                     "Neglect {} when setting params.".format(
-                        neglect_params
+                        [i for i in params if i in neglect_params]
                     )
                 )
                 del ret[v]
diff --git a/tf_pwa/data_trans/helicity_angle.py b/tf_pwa/data_trans/helicity_angle.py
index 57c1c7d..219532f 100644
--- a/tf_pwa/data_trans/helicity_angle.py
+++ b/tf_pwa/data_trans/helicity_angle.py
@@ -74,7 +74,9 @@ def get_all_mass(self, replace_mass):
         for i in self.decay_chain:
             for j in [i.core] + list(i.outs):
                 if j not in ms:
-                    if str(j) in replace_mass:
+                    if j in replace_mass:
+                        ms[j] = replace_mass[j]
+                    elif str(j) in replace_mass:
                         ms[j] = tf.convert_to_tensor(
                             replace_mass[str(j)], tf.float64
                         )
diff --git a/tf_pwa/fitfractions.py b/tf_pwa/fitfractions.py
index 6f4cabe..cd472f5 100644
--- a/tf_pwa/fitfractions.py
+++ b/tf_pwa/fitfractions.py
@@ -6,23 +6,26 @@
 from tf_pwa.data import LazyCall, data_split


-def eval_integral(
+def _eval_integral(
     f, data, var, weight=None, args=(), no_grad=False, kwargs=None
 ):
     kwargs = {} if kwargs is None else kwargs
     weight = 1.0 if weight is None else weight
     if no_grad:
-        ret = tf.reduce_sum(f(data, *args, **kwargs) * weight)
+        fx = f(data, *args, **kwargs) * weight
+        ret = tf.reduce_sum(fx)
         ret_grad = np.zeros((len(var),))
     else:
         with tf.GradientTape() as tape:
-            ret = tf.reduce_sum(f(data, *args, **kwargs) * weight)
+            fx = f(data, *args, **kwargs) * weight
+            ret = tf.reduce_sum(fx)
         ret_grad = tape.gradient(ret, var, unconnected_gradients="zero")
         if len(ret_grad) == 0:
             ret_grad = np.array([])
         else:
             ret_grad = np.stack([i.numpy() for i in ret_grad])
-    return ret.numpy(), ret_grad
+    int_square = tf.reduce_sum(fx**2)
+    return ret.numpy(), ret_grad, int_square.numpy()


 def force_list(x):
@@ -39,6 +42,7 @@ def __init__(self, amp, res):
         self.res = res
         self.cached_int = {}
         self.cached_grad = {}
+        self.cached_square = {}
         self.cached_int_total = 0.0
         self.cached_grad_total = np.zeros((self.n_var,))
         self.error_matrix = np.diag(np.zeros((self.n_var,)))
@@ -55,6 +59,7 @@ def init_res_table(self):
                 else:
                     name = (str(self.res[i]), str(self.res[j]))
                 self.cached_int[name] = 0.0
+                self.cached_square[name] = 0.0
                 self.cached_grad[name] = np.zeros_like((self.n_var,))

     def integral(self, mcdata, *args, batch=None, no_grad=False, **kwargs):
@@ -71,7 +76,7 @@ def append_int(self, mcdata, *args, weight=None, no_grad=False, **kwargs):
             mcdata = mcdata.eval()
         if weight is None:
             weight = mcdata.get("weight", 1.0)
-        int_mc, g_int_mc = eval_integral(
+        int_mc, g_int_mc, _ = _eval_integral(
             self.amp,
             mcdata,
             var=self.var,
@@ -92,7 +97,7 @@ def append_int(self, mcdata, *args, weight=None, no_grad=False, **kwargs):
            else:
                name = (str(self.res[i]), str(self.res[j]))
                amp_tmp.set_used_res(fl(self.res[i]) + fl(self.res[j]))
-            int_tmp, g_int_tmp = eval_integral(
+            int_tmp, g_int_tmp, int_square = _eval_integral(
                amp_tmp,
                mcdata,
                var=self.var,
@@ -102,6 +107,9 @@ def append_int(self, mcdata, *args, weight=None, no_grad=False, **kwargs):
            )
            self.cached_int[name] = self.cached_int[name] + int_tmp
            self.cached_grad[name] = self.cached_grad[name] + g_int_tmp
+            self.cached_square[name] = (
+                self.cached_square[name] + int_square
+            )

        self.amp.set_used_res(cahced_res)

@@ -142,7 +150,7 @@ def get_frac_grad(self, sum_diag=True):
             )
         return fit_frac, g_fit_frac

-    def get_frac(self, error_matrix=None, sum_diag=True):
+    def get_frac(self, error_matrix=None, sum_diag=True, add_int_error=False):
         if error_matrix is None:
             error_matrix = self.error_matrix
         fit_frac, g_fit_frac = self.get_frac_grad(sum_diag=sum_diag)
@@ -152,6 +160,9 @@ def get_frac(self, error_matrix=None, sum_diag=True):
         for k, v in g_fit_frac.items():
             e = np.sqrt(np.dot(np.dot(error_matrix, v), v))
             fit_frac_err[k] = e
+            if add_int_error and k in self.cached_square:
+                scale = 1 / self.cached_int_total**2
+                fit_frac_err[k] = np.sqrt(e**2 + self.cached_square[k] * scale)
         return fit_frac, fit_frac_err

     def __iter__(self):
diff --git a/tf_pwa/tests/test_full.py b/tf_pwa/tests/test_full.py
index b387f62..35f8ec6 100644
--- a/tf_pwa/tests/test_full.py
+++ b/tf_pwa/tests/test_full.py
@@ -341,7 +341,7 @@ def pw_f(x, **kwargs):
     fit_result.save_as("toy_data/final_params.json")
     fit_frac, frac_err = toy_config.cal_fitfractions()
     fit_frac, frac_err = toy_config.cal_fitfractions(method="new")
-    fit_frac, frac_err = toy_config.cal_fitfractions(
+    fit_frac_obj = toy_config.cal_fitfractions(
         method="new", res=["R_BC", ["R_BD", "R_CD"]]
     )
     save_frac_csv("toy_data/fit_frac.csv", fit_frac)
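
Note on the new option: with add_int_error=True, get_frac combines the error propagated
from the fit parameters with an extra term cached_square[k] / cached_int_total**2 that
accounts for the finite Monte Carlo sample used for the integrals. A minimal standalone
sketch of that combination, in plain numpy with illustrative numbers (grad, error_matrix,
cached_square and cached_int_total below are placeholders, not values from this patch):

    import numpy as np

    # gradient of one fit fraction w.r.t. the fit parameters (placeholder values)
    grad = np.array([0.12, -0.05, 0.03])
    # covariance matrix of the fit parameters (placeholder values)
    error_matrix = np.diag([1e-4, 2e-4, 5e-5])

    # error propagated from the fit parameters, as in the existing get_frac
    e = np.sqrt(np.dot(np.dot(error_matrix, grad), grad))

    # extra MC-integration term enabled by add_int_error=True:
    # sum of squared weighted amplitudes over the squared total integral
    cached_square = 2.5e3
    cached_int_total = 1.0e4
    total_err = np.sqrt(e**2 + cached_square / cached_int_total**2)
    print(e, total_err)

With the patch applied, the same combination is requested through the new keyword, e.g.
fit_frac, fit_frac_err = obj.get_frac(error_matrix, add_int_error=True), where obj is the
fit-fraction object whose __init__/append_int/get_frac appear above and whose integrals
have already been accumulated.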