Skip to content

Commit

Permalink
Merge pull request #158 from wwdws1/resolution_correct
Browse files Browse the repository at this point in the history
Correct missing parameter in resolution script
  • Loading branch information
jiangyi15 authored Nov 11, 2024
2 parents 716d3a8 + 63115b6 commit 68f3d58
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 14 deletions.
6 changes: 5 additions & 1 deletion checks/resolution/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,11 @@ def main():
# dat = ha.build_data(ms, costheta, phi)

p4, w = random_sample(
config, decay_chain, toy, smear_method=results.method
config,
decay_chain,
toy,
particle=results.particle,
smear_method=results.method,
)
w = toy.get_weight() * w
save_name = config.data.dic[name[:-4]]
Expand Down
60 changes: 47 additions & 13 deletions fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,17 +174,29 @@ def fit(
return fit_result


def write_some_results(config, fit_result, save_root=False, cpu_plot=False):
def write_some_results(
config,
fit_result,
save_root=False,
cpu_plot=False,
plot_figure=True,
add_chi2=False,
):
# plot partial wave distribution
if cpu_plot:
if plot_figure and cpu_plot:
with tf.device("CPU"):
config.plot_partial_wave(
fit_result, plot_pull=True, save_root=save_root
fit_result,
plot_pull=True,
save_root=save_root,
add_chi2=add_chi2,
)
else:
elif plot_figure:
config.plot_partial_wave(
fit_result, plot_pull=True, save_root=save_root
fit_result, plot_pull=True, save_root=save_root, add_chi2=add_chi2
)
else:
print("No plot.")

# calculate fit fractions
phsp_noeff = config.get_phsp_noeff()
Expand All @@ -209,23 +221,35 @@ def write_some_results(config, fit_result, save_root=False, cpu_plot=False):


def write_some_results_combine(
config, fit_result, save_root=False, cpu_plot=False
config,
fit_result,
save_root=False,
cpu_plot=False,
plot_figure=True,
add_chi2=False,
):

from tf_pwa.applications import fit_fractions

for i, c in enumerate(config.configs):
if cpu_plot:
with tf.device("CPU"):
if plot_figure:
for i, c in enumerate(config.configs):
if cpu_plot:
with tf.device("CPU"):
c.plot_partial_wave(
fit_result,
prefix="figure/s{}_".format(i),
save_root=save_root,
add_chi2=add_chi2,
)
else:
c.plot_partial_wave(
fit_result,
prefix="figure/s{}_".format(i),
save_root=save_root,
add_chi2=add_chi2,
)
else:
c.plot_partial_wave(
fit_result, prefix="figure/s{}_".format(i), save_root=save_root
)
else:
print("No plot.")

for it, config_i in enumerate(config.configs):
print("########## fit fractions {}:".format(it))
Expand Down Expand Up @@ -279,6 +303,12 @@ def main():
parser.add_argument(
"--CPU-plot", action="store_true", default=False, dest="cpu_plot"
)
parser.add_argument(
"--no-plot", action="store_false", default=True, dest="plot_figure"
)
parser.add_argument(
"--add-chi2", action="store_true", default=False, dest="add_chi2"
)
parser.add_argument(
"-c",
"--config",
Expand Down Expand Up @@ -320,13 +350,17 @@ def main():
fit_result,
save_root=results.save_root,
cpu_plot=results.cpu_plot,
plot_figure=results.plot_figure,
add_chi2=results.add_chi2,
)
else:
write_some_results_combine(
config,
fit_result,
save_root=results.save_root,
cpu_plot=results.cpu_plot,
plot_figure=results.plot_figure,
add_chi2=results.add_chi2,
)


Expand Down

0 comments on commit 68f3d58

Please sign in to comment.