Skip to content

Commit

Permalink
Merge pull request #147 from jiangyi15/mlp
Browse files Browse the repository at this point in the history
MLP model
  • Loading branch information
jiangyi15 authored Jun 11, 2024
2 parents 8d419f5 + db93972 commit 0db7545
Show file tree
Hide file tree
Showing 10 changed files with 299 additions and 0 deletions.
26 changes: 26 additions & 0 deletions checks/fit_shape/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Script to directly fit the shape with MLP

# input:

- config.yml: target

- init_params.json: parameters for target model

- config_mi.yml: fit model

# fit script

- fit_shape.py

# output:

- fit_results.png: plot of real and image

- final_params.json: results of fit parameters

# example

- final*params*\*.json: one possible fit results with activation name in the
tail.

The best model have loss about 31.85.
22 changes: 22 additions & 0 deletions checks/fit_shape/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
data:
dat_order: [B, C, D]

decay:
A: [BC, D]
BC: [B, C]

particle:
$top: A
$finals: [B, C, D]
A: { J: 0, P: -1, mass: 3.0 }
B: { J: 0, P: -1, mass: 0.1 }
C: { J: 0, P: -1, mass: 0.1 }
D: { J: 0, P: -1, mass: 0.1 }
BC: [BC1, BC2]
BC1: { J: 0, P: +1, mass: 1.0, width: 0.5 }
BC2: { J: 0, P: +1, mass: 2.0, width: 0.5 }

constrains:
decay:
fix_chain_idx: 0
fix_chain_val: 1
27 changes: 27 additions & 0 deletions checks/fit_shape/config_mi.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
data:
dat_order: [B, C, D]

decay:
A: [BC, D]
BC: [B, C]

particle:
$top: A
$finals: [B, C, D]
A: { J: 0, P: -1, mass: 3.0 }
B: { J: 0, P: -1, mass: 0.1 }
C: { J: 0, P: -1, mass: 0.1 }
D: { J: 0, P: -1, mass: 0.1 }
BC: [MI]
MI:
J: 0
P: +1
mass: 1.0
interp_N: 10
model: MLP
activation: leaky_relu

constrains:
decay:
fix_chain_idx: 0
fix_chain_val: 1
38 changes: 38 additions & 0 deletions checks/fit_shape/final_params_leaky_relu.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"MI_b_0": -0.8796771549166532,
"MI_b_1": -0.5555867175456078,
"MI_b_2": 0.25869784992909006,
"MI_b_3": -1.0627308851937134,
"MI_b_4": 3.6615951678704897,
"MI_b_5": -0.07954771940190528,
"MI_b_6": -0.256124862341781,
"MI_b_7": 14.23127083846775,
"MI_b_8": -1.286108632752361,
"MI_b_9": 0.07076624632726991,
"MI_w_0r": -5.104381116132341,
"MI_w_0i": 7.871238527765366,
"MI_w_1r": 3.23365301131465,
"MI_w_1i": 2.5985262831448064,
"MI_w_2r": -5.405269642657437,
"MI_w_2i": 5.314416098804127,
"MI_w_3r": 6.633822670781981,
"MI_w_3i": 6.899102757998181,
"MI_w_4r": -2.241794106880599,
"MI_w_4i": 4.787308576161889,
"MI_w_5r": 11.011051058950653,
"MI_w_5i": -0.7205598197812919,
"MI_w_6r": 7.060488845223941,
"MI_w_6i": 7.225803210685805,
"MI_w_7r": 0.4388703291081927,
"MI_w_7i": -1.2006659389428078,
"MI_w_8r": 3.8879823702434892,
"MI_w_8i": 2.574115900343803,
"MI_w_9r": 10.533495767231539,
"MI_w_9i": -2.4450624359388207,
"A->MI.DMI->B.C_total_0r": 1.0,
"A->MI.DMI->B.C_total_0i": 0.0,
"A->MI.D_g_ls_0r": 1.0,
"A->MI.D_g_ls_0i": 0.0,
"MI->B.C_g_ls_0r": 1.0,
"MI->B.C_g_ls_0i": 0.0
}
38 changes: 38 additions & 0 deletions checks/fit_shape/final_params_relu.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"MI_b_0": -1.0622250950337662,
"MI_b_1": -0.8789283255768763,
"MI_b_2": 3.0599096017088203,
"MI_b_3": -1.285576319474175,
"MI_b_4": -0.255100410428076,
"MI_b_5": -0.5535368004801267,
"MI_b_6": 7.558983658698586,
"MI_b_7": 0.07110322552772569,
"MI_b_8": 0.2590192741152075,
"MI_b_9": -0.07886432386304608,
"MI_w_0r": 5.3129601453575965,
"MI_w_0i": 6.89330400075596,
"MI_w_1r": 4.076189414869148,
"MI_w_1i": 4.72169072366533,
"MI_w_2r": 2.2345575394656794,
"MI_w_2i": 1.5674140496474243,
"MI_w_3r": 3.1181489763446524,
"MI_w_3i": 2.5713070284633615,
"MI_w_4r": 5.664131964810804,
"MI_w_4i": 7.217921257894127,
"MI_w_5r": -2.5927805257104555,
"MI_w_5i": 5.731281787333332,
"MI_w_6r": 0.6947219631515521,
"MI_w_6i": -1.3812766994545103,
"MI_w_7r": -8.411867876040107,
"MI_w_7i": 6.974832219369253,
"MI_w_8r": 4.313161954304356,
"MI_w_8i": 2.171071429023437,
"MI_w_9r": 8.808426082864889,
"MI_w_9i": 5.556405458272451,
"A->MI.DMI->B.C_total_0r": 1.0,
"A->MI.DMI->B.C_total_0i": 0.0,
"A->MI.D_g_ls_0r": 1.0,
"A->MI.D_g_ls_0i": 0.0,
"MI->B.C_g_ls_0r": 1.0,
"MI->B.C_g_ls_0i": 0.0
}
69 changes: 69 additions & 0 deletions checks/fit_shape/fit_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tf_pwa.config_loader import ConfigLoader
from tf_pwa.utils import time_print

config = ConfigLoader("config.yml")

config.set_params("init_params.json")

config_mi = ConfigLoader("config_mi.yml")
config_mi.get_params()

# config_mi.set_params("final_params.json")

config_mi.vm.set_fix("MI_w_0r", unfix=True)
config_mi.vm.set_fix("MI_w_0i", unfix=True)


f1 = config.get_particle_function("BC1")
f2 = config.get_particle_function("BC2")

f = config_mi.get_particle_function("MI")

m = f.mass_linspace(10000)


fast_f = f.cached_call(m)
target_f = f1(m) + f2(m)


plot_count = 1


def f_loss():
ret = tf.reduce_sum(tf.abs(fast_f() - target_f) ** 2)
global plot_count
if plot_count % 10 == 1:
print(ret)
plot_count += 1
return ret


best_params = {}
best_loss = np.inf
best_fit_result = None
for i in range(1):
fit_result = time_print(config_mi.vm.minimize)(f_loss)
if fit_result.fun < best_loss:
best_loss = fit_result.fun
best_params = config_mi.get_params()
best_fit_result = fit_result
# reset random parameters
config_mi2 = ConfigLoader("config_mi.yml")
config_mi.set_params(config_mi2.get_params())

config_mi.set_params(best_params)
config_mi.save_params("final_params.json")

print(best_fit_result)

plt.plot(m, tf.math.imag(f(m)).numpy(), label="image fit")
plt.plot(m, tf.math.imag(f1(m) + f2(m)).numpy(), label="imag target")
plt.plot(m, tf.math.real(f(m)).numpy(), label="real fit")
plt.plot(m, tf.math.real(f1(m) + f2(m)).numpy(), label="real target")

plt.legend()
plt.savefig("fit_results.png")
18 changes: 18 additions & 0 deletions checks/fit_shape/init_params.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"BC1_mass": 1.0,
"BC1_width": 0.5,
"BC2_mass": 2.0,
"BC2_width": 0.5,
"A->BC1.DBC1->B.C_total_0r": 1.0,
"A->BC1.DBC1->B.C_total_0i": 0.0,
"A->BC1.D_g_ls_0r": 1.0,
"A->BC1.D_g_ls_0i": 0.0,
"BC1->B.C_g_ls_0r": 1.0,
"BC1->B.C_g_ls_0i": 0.0,
"A->BC2.DBC2->B.C_total_0r": 1.0,
"A->BC2.DBC2->B.C_total_0i": 1.5,
"A->BC2.D_g_ls_0r": 1.0,
"A->BC2.D_g_ls_0i": 0.0,
"BC2->B.C_g_ls_0r": 1.0,
"BC2->B.C_g_ls_0i": 0.0
}
48 changes: 48 additions & 0 deletions tf_pwa/amp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,54 @@ def get_amp(self, data, _data_c=None, **kwargs):
return tf.math.polyval(pi, mass)


@regist_particle("MLP")
class ParticleMLP(Particle):
"""
Multilayer Perceptron like model.
.. math::
R(m) = \\sum_{k} w_k activation(m-m_0+b_k)
lineshape when `interp_N: 11`, `activation: relu`, :math:`b_k=(k-5)/10`, :math:`w_k = exp(k i\\pi/2)`
.. plot::
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> plt.clf()
>>> from tf_pwa.utils import plot_particle_model
>>> plot_params = {f"R_BC_b_{i}": (i-5)/10 for i in range(11)}
>>> plot_params.update({f"R_BC_w_{i}r": 1 for i in range(11)})
>>> plot_params.update({f"R_BC_w_{i}i": i * np.pi/2 for i in range(11)})
>>> axis = plot_particle_model("MLP", params={"interp_N": 11, "activation": "relu"}, plot_params=plot_params)
"""

activation_function = {
"relu2": lambda x: tf.nn.relu(x) ** 2,
"relu3": lambda x: tf.nn.relu(x) ** 3,
}

def init_params(self):
self.interp_N = getattr(self, "interp_N", 3)
self.activation = getattr(self, "activation", "leaky_relu")
self.activation_f = ParticleMLP.activation_function.get(
self.activation, getattr(tf.nn, self.activation)
)
self.bi = self.add_var("b", shape=(self.interp_N,))
self.wi = self.add_var("w", shape=(self.interp_N,), is_complex=True)
self.wi.set_fix_idx(fix_idx=0, fix_vals=(1.0, 0.0))

def get_amp(self, data, _data_c=None, **kwargs):
mass = data["m"] - self.get_mass()
bi = tf.stack(self.bi())
wi = tf.stack(self.wi())
x = tf.expand_dims(mass, axis=-1) + bi
x = self.activation_f(x)
ret = tf.reduce_sum(wi * tf.complex(x, tf.zeros_like(x)), axis=-1)
return ret


@regist_decay("particle-decay")
class ParticleDecay(HelicityDecay):
def get_ls_amp(self, data, data_p, **kwargs):
Expand Down
11 changes: 11 additions & 0 deletions tf_pwa/config_loader/particle_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ def __call__(self, m, random=False):
ret = a[self.idx]
return self.norm_factor * ret

def cached_call(self, m, **kwargs):
p = self.ha.generate_p_mass(self.name, m, **kwargs)
data = self.config.data.cal_angle(p)

def f():
a = build_amp.build_params_vector(self.decay_group, data)
ret = a[self.idx]
return self.norm_factor * ret

return f

def mass_range(self):
return self.ha.get_mass_range(self.name)

Expand Down
2 changes: 2 additions & 0 deletions tf_pwa/config_loader/tests/test_particle_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def test_particle(toy_config):
assert f(m).shape == (5, 6)
g = toy_config.get_particle_function("R_BC", d_norm=True)
g.phsp_fractor(m)
cached_g = g.cached_call(m)
assert np.allclose(g(m).numpy(), cached_g().numpy())
g.density(m)
assert g(m).shape == (5, 6)
m_min, m_max = g.mass_range()
Expand Down

0 comments on commit 0db7545

Please sign in to comment.