Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better summary plot for lightcurve-analysis #365

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 91 additions & 49 deletions nmma/em/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,7 @@ def analysis(args):
if args.plot:
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from mpl_toolkits.axes_grid1 import make_axes_locatable

if len(models) > 1:
_, mag_all = light_curve_model.generate_lightcurve(
Expand Down Expand Up @@ -1003,77 +1004,137 @@ def analysis(args):
colors = cm.Spectral(np.linspace(0, 1, len(filters_plot)))[::-1]

plotName = os.path.join(args.outdir, f"{args.label}_lightcurves.png")
plt.figure(figsize=(20, 16))
color2 = "coral"

# set up the geometry for the all-in-one figure
wspace = 0.6 # All in inches.
hspace = 0.3
lspace = 1.0
bspace = 0.7
trspace = 0.2
hpanel = 2.25
wpanel = 3.

ncol = 2
nrow = int(np.ceil(len(filters_plot) / ncol))
fig, axes = plt.subplots(nrow, ncol)

figsize = (1.5 * (lspace + wpanel * ncol + wspace * (ncol - 1) + trspace),
1.5 * (bspace + hpanel * nrow + hspace * (nrow - 1) + trspace))
# Create the figure and axes.
fig, axes = plt.subplots(nrow, ncol, figsize=figsize, squeeze=False)
fig.subplots_adjust(left=lspace / figsize[0],
bottom=bspace / figsize[1],
right=1. - trspace / figsize[0],
top=1. - trspace / figsize[1],
wspace=wspace / wpanel,
hspace=hspace / hpanel)

if len(filters_plot) % 2:
axes[-1, -1].axis('off')

cnt = 0
for filt, color in zip(filters_plot, colors):
cnt = cnt + 1
if cnt == 1:
ax1 = plt.subplot(len(filters_plot), 1, cnt)

# summary plot
row = (cnt - 1) // ncol
col = (cnt - 1) % ncol
ax_sum = axes[row, col]
# adding the ax for the Delta
divider = make_axes_locatable(ax_sum)
ax_delta = divider.append_axes('bottom',
size='30%',
sharex=ax_sum)

# configuring ax_sum
ax_sum.set_ylabel("AB magnitude", rotation=90)
ax_delta.set_ylabel(r"$\Delta (\sigma)$")
if cnt == len(filters_plot) or cnt == len(filters_plot) - 1:
ax_delta.set_xlabel("Time [days]")
else:
ax2 = plt.subplot(len(filters_plot), 1, cnt, sharex=ax1, sharey=ax1)
ax_delta.set_xticklabels([])

# plotting the best-fit lc and the data in ax1
samples = data[filt]
t, y, sigma_y = samples[:, 0], samples[:, 1], samples[:, 2]
t -= trigger_time + timeshift
idx = np.where(~np.isnan(y))[0]
t, y, sigma_y = t[idx], y[idx], sigma_y[idx]

idx = np.where(np.isfinite(sigma_y))[0]
plt.errorbar(
det_idx = idx
ax_sum.errorbar(
t[idx],
y[idx],
sigma_y[idx],
fmt="o",
color="k",
markersize=16,
) # or color=color
color=color,
)

idx = np.where(~np.isfinite(sigma_y))[0]
plt.plot(
t[idx], y[idx], marker="v", color="k", markersize=16
) # or color=color

ax_sum.errorbar(
t[idx],
y[idx],
sigma_y[idx],
fmt="v",
color=color,
)

mag_plot = getFilteredMag(mag, filt)

plt.plot(
# calculating the chi2
mag_per_data = np.interp(
t[det_idx],
mag["bestfit_sample_times"],
mag_plot)
diff_per_data = mag_per_data - y[det_idx]
sigma_per_data = np.sqrt((sigma_y[det_idx]**2 + error_budget[filt]**2))
chi2_per_data = diff_per_data**2
chi2_per_data /= sigma_per_data**2
chi2_total = np.sum(chi2_per_data)
N_data = len(det_idx)

# plot the mismatch between the model and the data
ax_delta.scatter(t[det_idx], diff_per_data / sigma_per_data, color=color)
ax_delta.axhline(0, linestyle='--', color='k')

ax_sum.plot(
mag["bestfit_sample_times"],
mag_plot,
color=color2,
color='coral',
linewidth=3,
linestyle="--",
)

if len(models) > 1:
plt.fill_between(
ax_sum.fill_between(
mag["bestfit_sample_times"],
mag_plot + error_budget[filt],
mag_plot - error_budget[filt],
facecolor=color2,
facecolor='coral',
alpha=0.2,
label="Combined",
label="combined",
)
else:
plt.fill_between(
ax_sum.fill_between(
mag["bestfit_sample_times"],
mag_plot + error_budget[filt],
mag_plot - error_budget[filt],
facecolor=color2,
facecolor='coral',
alpha=0.2,
)

if len(models) > 1:
for ii in range(len(mag_all)):
mag_plot = getFilteredMag(mag_all[ii], filt)
plt.plot(
ax_sum.plot(
mag["bestfit_sample_times"],
mag_plot,
color=color2,
color='coral',
linewidth=3,
linestyle="--",
)
plt.fill_between(
ax_sum.fill_between(
mag["bestfit_sample_times"],
mag_plot + error_budget[filt],
mag_plot - error_budget[filt],
Expand All @@ -1082,32 +1143,13 @@ def analysis(args):
label=models[ii].model,
)

plt.ylabel("%s" % filt, fontsize=48, rotation=0, labelpad=40)

plt.xlim([float(x) for x in args.xlim.split(",")])
plt.ylim([float(x) for x in args.ylim.split(",")])
plt.grid()

if cnt == 1:
ax1.set_yticks([26, 22, 18, 14])
plt.setp(ax1.get_xticklabels(), visible=False)
if len(models) > 1:
plt.legend(
loc="upper right",
prop={"size": 18},
numpoints=1,
shadow=True,
fancybox=True,
)
elif not cnt == len(filters_plot):
plt.setp(ax2.get_xticklabels(), visible=False)
plt.xticks(fontsize=36)
plt.yticks(fontsize=36)

ax1.set_zorder(1)
plt.xlabel("Time [days]", fontsize=48)
plt.tight_layout()
plt.savefig(plotName)
ax_sum.set_title(f'{filt}: ' + fr'$\chi^2 / d.o.f. = {round(chi2_total / N_data, 2)}$')

ax_sum.set_xlim([float(x) for x in args.xlim.split(",")])
ax_sum.set_ylim([float(x) for x in args.ylim.split(",")])
ax_delta.set_xlim([float(x) for x in args.xlim.split(",")])

plt.savefig(plotName, bbox_inches='tight')
plt.close()


Expand Down
Loading