diff --git a/tf_pwa/config_loader/config_loader.py b/tf_pwa/config_loader/config_loader.py index fdae923..1c14f4f 100644 --- a/tf_pwa/config_loader/config_loader.py +++ b/tf_pwa/config_loader/config_loader.py @@ -35,6 +35,7 @@ ) from tf_pwa.cal_angle import prepare_data_from_decay from tf_pwa.data import ( + ReadData, data_index, data_merge, data_shape, @@ -1225,6 +1226,10 @@ def __init__(self, plot_config, decay_struct): self.params.append(i) for i in self.get_angle_vars(True): self.params.append(i) + for i in self.get_index_vars(): + self.params.append(i) + for i in self.get_extra_vars(): + self.params.append(i) def get_data_index(self, sub, name): dec = self.decay_struct.topology_structure() @@ -1268,16 +1273,39 @@ def get_data_index(self, sub, name): self.re_map.get(p, p), "aligned_angle", ) + if sub == "index": + name_i = name.split("/") + return name_i raise ValueError("unknown sub {}".format(sub)) + def read_plot_config(self, v): + upper_ylim = v.get("upper_ylim", None) + xrange = v.get("range", None) + units = v.get("units", "") + bins = v.get("bins", self.defaults_config.get("bins", 50)) + legend = v.get("legend", self.defaults_config.get("legend", True)) + legend_outside = v.get( + "legend_outside", + self.defaults_config.get("legend_outside", False), + ) + yscale = v.get("yscale", self.defaults_config.get("yscale", "linear")) + upper_ylim = v.get("upper_ylim", None) + return { + "upper_ylim": upper_ylim, + "legend": legend, + "legend_outside": legend_outside, + "range": xrange, + "bins": bins, + "units": units, + "yscale": yscale, + } + def get_mass_vars(self): mass = self.config.get("mass", {}) x = sy.symbols("x") for k, v in mass.items(): id_ = v.get("id", k) display = v.get("display", "M({})".format(k)) - upper_ylim = v.get("upper_ylim", None) - xrange = v.get("range", None) trans = v.get("trans", None) if trans is None: trans = lambda x: x @@ -1285,31 +1313,20 @@ def get_mass_vars(self): trans = sy.sympify(trans) trans = sy.lambdify(x, trans, modules="numpy") units = v.get("units", "GeV") - bins = v.get("bins", self.defaults_config.get("bins", 50)) - legend = v.get("legend", self.defaults_config.get("legend", True)) - legend_outside = v.get( - "legend_outside", - self.defaults_config.get("legend_outside", False), - ) - yscale = v.get( - "yscale", self.defaults_config.get("yscale", "linear") + common_config = self.read_plot_config(v) + idx = ( + "particle", + self.re_map.get(get_particle(id_), get_particle(id_)), + "m", ) yield { + **common_config, + "units": units, "name": "m_" + k, "display": display, - "upper_ylim": upper_ylim, - "idx": ( - "particle", - self.re_map.get(get_particle(id_), get_particle(id_)), - "m", - ), - "legend": legend, - "legend_outside": legend_outside, - "range": xrange, - "bins": bins, + "idx": idx, "trans": trans, - "units": units, - "yscale": yscale, + "readdata": ReadData(idx, trans), } def get_angle_vars(self, is_align=False): @@ -1347,24 +1364,12 @@ def get_angle_vars(self, is_align=False): ) for j, v in i.items(): display = v.get("display", j) - upper_ylim = v.get("upper_ylim", None) theta = j trans = lambda x: x if "cos" in j: theta = j[4:-1] trans = np.cos - bins = v.get("bins", self.defaults_config.get("bins", 50)) - xrange = v.get("range", None) - legend = v.get( - "legend", self.defaults_config.get("legend", False) - ) - legend_outside = v.get( - "legend_outside", - self.defaults_config.get("legend_outside", False), - ) - yscale = v.get( - "yscale", self.defaults_config.get("yscale", "linear") - ) + common_config = self.read_plot_config(v) if is_align: ang_type = "aligned_angle" else: @@ -1372,26 +1377,78 @@ def get_angle_vars(self, is_align=False): name_id = validate_file_name(k + "_" + j) if is_align: name_id = "aligned_" + name_id + idx = ("decay", decay_chain, decay, part, ang_type, theta) yield { + **common_config, "name": name_id, "display": display, - "upper_ylim": upper_ylim, - "idx": ( - "decay", - decay_chain, - decay, - part, - ang_type, - theta, - ), + "idx": idx, "trans": trans, - "bins": bins, - "range": xrange, - "legend": legend, - "legend_outside": legend_outside, - "yscale": yscale, + "readdata": ReadData(idx, trans), } + def get_extra_vars(self): + + from tf_pwa.formula import build_expr_function + + dic = self.config.get("extra_vars", {}) + + for k, v in dic.items(): + expr = v["expr"] + where = v.get("where", {}) + f_expr, used_var = build_expr_function(expr) + + var_f = [] + for i in used_var: + idx = where.get(i, i) + if isinstance(idx, (list, tuple)): + idx = idx[0], "/".join(idx[1]) + var_f = ReadData(self.get_data_index(*idx)) + elif isinstance(idx, str): + for j in self.params: + if j["name"] == idx: + var_f.append(j["readdata"]) + else: + raise TypeError("unknown variables for trans ") + + def readdata(x): + var = [data_to_numpy(i(x)) for i in var_f] + return f_expr(**dict(zip(used_var, var))) + + id_ = v.get("id", k) + display = v.get("display", str(expr)) + common_config = self.read_plot_config(v) + yield { + **common_config, + "name": k, + "display": display, + "readdata": readdata, + } + + def get_index_vars(self): + + dic = self.config.get("index", {}) + + for k, v in dic.items(): + idx = self.get_data_index("index", k) + id_ = v.get("id", k) + display = v.get("display", str(k)) + trans = v.get("trans", None) + if trans is None: + trans = lambda x: x + else: + trans = sy.sympify(trans) + x = sy.symbols("x") + trans = sy.lambdify(x, trans, modules="numpy") + common_config = self.read_plot_config(v) + readdata = ReadData(idx, trans) + yield { + **common_config, + "name": k, + "display": display, + "readdata": readdata, + } + def get_params(self, params=None): if params is None: return self.params diff --git a/tf_pwa/config_loader/multi_config.py b/tf_pwa/config_loader/multi_config.py index b5ca1b6..ed12af1 100644 --- a/tf_pwa/config_loader/multi_config.py +++ b/tf_pwa/config_loader/multi_config.py @@ -240,8 +240,9 @@ def save_params(self, file_name): with open(file_name, "w") as f: json.dump(val, f, indent=2) - def plot_partial_wave(self, params=None, prefix="figure/all", **kwargs): - + def _get_plot_partial_wave_input( + self, params=None, prefix="figure/all", save_root=False, **kwargs + ): path = os.path.dirname(prefix) os.makedirs(path, exist_ok=True) @@ -276,7 +277,28 @@ def plot_partial_wave(self, params=None, prefix="figure/all", **kwargs): print("com_plot: set", k, "to 0 for sample", idx) phsp[k] = np.zeros_like(phsp["MC_total_fit"]) phsp_dict = data_to_numpy(data_merge(*[i[1] for i in all_data])) + + if save_root: + save_dict_to_root( + [data_dict, phsp_dict, bg_dict], + file_name=prefix + "variables_com.root", + tree_name=["data", "fitted", "sideband"], + ) + print("Save root file " + prefix + "com_variables.root") + + return (data_dict, phsp_dict, bg_dict), extra + + def plot_partial_wave( + self, params=None, prefix="figure/all", save_root=False, **kwargs + ): + + data, extra = self._get_plot_partial_wave_input( + params=params, prefix=prefix, save_root=save_root, **kwargs + ) + + data_dict, phsp_dict, bg_dict = data _, plot_var_dic, chain_property, nll = extra + self.configs[-1]._plot_partial_wave( data_dict, phsp_dict, diff --git a/tf_pwa/config_loader/plot.py b/tf_pwa/config_loader/plot.py index dc212ae..b56df93 100644 --- a/tf_pwa/config_loader/plot.py +++ b/tf_pwa/config_loader/plot.py @@ -217,17 +217,22 @@ def create_chain_property(self, res): return chain_property -def create_plot_var_dic(plot_params): +def create_plot_var_dic(plot_params, extra_plots=None): + extra_plots = [] if extra_plots is None else extra_plots plot_var_dic = {} - for conf in plot_params.get_params(): + common_bins = None + for conf in plot_params.get_params() + extra_plots: name = conf.get("name") display = conf.get("display", name) upper_ylim = conf.get("upper_ylim", None) - idx = conf.get("idx") + idx = conf.get("idx", (None,)) trans = conf.get("trans", lambda x: x) + readdata = conf.get("readdata") has_legend = conf.get("legend", False) xrange = conf.get("range", None) - bins = conf.get("bins", None) + bins = conf.get("bins", common_bins) + if common_bins is None: + common_bins = bins legend_outside = conf.get("legend_outside", False) units = conf.get("units", "") yscale = conf.get("yscale", "linear") @@ -238,6 +243,7 @@ def create_plot_var_dic(plot_params): "legend_outside": legend_outside, "idx": idx, "trans": trans, + "readdata": readdata, "range": xrange, "bins": bins, "units": units, @@ -376,6 +382,7 @@ def _get_plot_partial_wave_input( chains_id_method=None, cut_function=lambda x: 1, partial_waves_function=None, + extra_plots=None, **kwargs ): """ @@ -443,7 +450,9 @@ def _get_plot_partial_wave_input( [i, "pw_{}".format(i), "partial waves {}".format(i), None] for i in range(100) ] - plot_var_dic = create_plot_var_dic(self.plot_params) + plot_var_dic = create_plot_var_dic( + self.plot_params, extra_plots=extra_plots + ) if self._Ngroup == 1: data_dict, phsp_dict, bg_dict = self._cal_partial_wave( @@ -660,12 +669,10 @@ def _cal_partial_wave( weight_i ) for name in plot_var_dic: - idx = plot_var_dic[name]["idx"] - trans = lambda x: np.reshape(plot_var_dic[name]["trans"](x), (-1,)) + readdata = plot_var_dic[name]["readdata"] + idx = plot_var_dic[name].get("idx", (None,)) - data_i = batch_call_numpy( - lambda x: trans(data_index(x, idx)), data, batch - ) + data_i = batch_call_numpy(readdata, data, batch) if idx[-1] == "m": tmp_idx = list(idx) tmp_idx[-1] = "p" @@ -682,15 +689,11 @@ def _cal_partial_wave( data_dict[name + "_PZ"] = p4[3] data_dict[name] = data_i # data variable - phsp_i = batch_call_numpy( - lambda x: trans(data_index(x, idx)), phsp_rec, batch - ) + phsp_i = batch_call_numpy(readdata, phsp_rec, batch) phsp_dict[name + "_MC"] = phsp_i # MC if bg is not None: - bg_i = batch_call_numpy( - lambda x: trans(data_index(x, idx)), bg, batch - ) + bg_i = batch_call_numpy(readdata, bg, batch) bg_dict[name + "_sideband"] = bg_i # sideband data_dict = data_to_numpy(data_dict) phsp_dict = data_to_numpy(phsp_dict) @@ -1391,7 +1394,7 @@ def plot_function_2dpull( normal = mpl.colors.Normalize(vmin=-max_weight, vmax=max_weight) im = mpl.cm.ScalarMappable(norm=normal, cmap=my_cmap) # ax.colorbar(im) - ax.get_figure().colorbar(im) + ax.get_figure().colorbar(im, ax=ax) ax.set_title( "$\\chi^2/Nbins={:.2f}/{}$".format( np.sum(np.abs(pulls) ** 2), len(bound) diff --git a/tf_pwa/config_loader/plotter.py b/tf_pwa/config_loader/plotter.py index ab02f0f..ae996e9 100644 --- a/tf_pwa/config_loader/plotter.py +++ b/tf_pwa/config_loader/plotter.py @@ -15,7 +15,7 @@ _get_cfit_eff_phsp, create_chain_property, ) -from tf_pwa.data import batch_call, data_index, data_shape +from tf_pwa.data import ReadData, batch_call, data_index, data_shape from tf_pwa.histogram import Hist1D, WeightedData logger = logging.getLogger(__file__) @@ -56,20 +56,6 @@ def get_histogram(self, var, partial=None, **kwargs): return Hist1D.histogram(value, weights=w * self.scale, **kwargs) -class ReadData: - def __init__(self, var, trans=None): - self.var = var - self.trans = (lambda x: x) if trans is None else trans - - def __call__(self, data): - value = data_index(data, self.var) - value = self.trans(value) - return value - - def __repr__(self): - return str(self.var) - - class PlotDataGroup: def __init__(self, datasets): self.datasets = datasets @@ -173,16 +159,14 @@ def get_all_frame(self): name = conf.get("name") display = conf.get("display", name) upper_ylim = conf.get("upper_ylim", None) - idx = conf.get("idx") - trans = conf.get("trans", lambda x: x) + readdata = conf.get("readdata") has_legend = conf.get("legend", False) xrange = conf.get("range", None) bins = conf.get("bins", None) units = conf.get("units", "") yscale = conf.get("yscale", "linear") ret[name] = Frame( - idx, - trans=trans, + readdata, name=name, display=display, x_range=xrange, diff --git a/tf_pwa/data.py b/tf_pwa/data.py index 8193844..121f3a6 100644 --- a/tf_pwa/data.py +++ b/tf_pwa/data.py @@ -735,3 +735,17 @@ def _check_nan(dat, head): return True return _check_nan(data, head_keys) + + +class ReadData: + def __init__(self, var, trans=None): + self.var = var + self.trans = (lambda x: x) if trans is None else trans + + def __call__(self, data): + value = data_index(data, self.var) + value = self.trans(value) + return value + + def __repr__(self): + return str(self.var) diff --git a/tf_pwa/formula.py b/tf_pwa/formula.py index 7eda5b7..ba1b9f8 100644 --- a/tf_pwa/formula.py +++ b/tf_pwa/formula.py @@ -1,3 +1,4 @@ +import numpy as np import sympy from tf_pwa.breit_wigner import get_bprime_coeff @@ -141,3 +142,25 @@ def create_numpy_function(f, var, val, x, modules="numpy"): f_val = _flatten(val) f = f.subs(dict(zip(f_var, f_val))) return sympy.lambdify(x, f, modules=modules) + + +def build_expr_function(expr): + + expr = sympy.simplify(expr) + + all_symbols = expr.free_symbols + all_symbols = tuple(all_symbols) + + used_var = [] + for i in all_symbols: + used_var.append(str(i)) + + custom_function = { + "float": lambda x: np.array(x).astype(np.float64), + "int": lambda x: np.array(x).astype(np.int32), + } + + f_expr = sympy.lambdify( + all_symbols, expr, modules=[custom_function, "numpy"] + ) + return f_expr, used_var diff --git a/tf_pwa/tests/config_cfit.yml b/tf_pwa/tests/config_cfit.yml index 29cff6c..212116a 100644 --- a/tf_pwa/tests/config_cfit.yml +++ b/tf_pwa/tests/config_cfit.yml @@ -51,6 +51,11 @@ plot: aligned_angle: R_BC/B: cos(beta): { display: "$cos(\\beta)$" } + index: + charge_conjugation: { display: "$c$" } + extra_vars: + m_BC2: + expr: m_R_BC + m_R_CD 2Dplot: dalitz_12: x: m_R_BC**2 diff --git a/tf_pwa/tests/test_full.py b/tf_pwa/tests/test_full.py index 7e37d7d..e64f49f 100644 --- a/tf_pwa/tests/test_full.py +++ b/tf_pwa/tests/test_full.py @@ -165,10 +165,15 @@ def test_cfit(gen_toy): linestyle_file="toy_data/a.yml", chains_id_method="res", ) + + def f(x): + return x.get_weight() + config.plot_partial_wave_interf( "R_BC", "R_BD", prefix="toy_data/figure/interf_", + extra_plots=[{"name": "weight", "readdata": f}], ) config.plot_partial_wave_interf( "R_BC", diff --git a/tf_pwa/variable.py b/tf_pwa/variable.py index 78de19c..5a0e477 100644 --- a/tf_pwa/variable.py +++ b/tf_pwa/variable.py @@ -426,8 +426,9 @@ def set_fix(self, name, value=None, unfix=False): else: if name in self.bnd_dic: value = self.bnd_dic[name].get_y2x(value) - self.variables[name].assign(value) - self.variables[name]._trainable = unfix + if name in self.variables: + self.variables[name].assign(value) + self.variables[name]._trainable = unfix if unfix: if name in self.trainable_vars: warnings.warn("{} has been freed already!".format(name))