Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix svdmodel_benchmark filter specification, add unit test #303

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nmma/em/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
model_parameters_dict = {
"Bu2019nsbh": ["log10_mej_dyn", "log10_mej_wind", "KNtheta"],
"Bu2019lm": ["log10_mej_dyn", "log10_mej_wind", "KNphi", "KNtheta"],
"Bu2019lm_sparse": ["log10_mej_dyn", "log10_mej_wind"],
"Ka2017": ["log10_mej", "log10_vej", "log10_Xlan"],
"TrPi2018": [
"inclination_EM",
Expand Down
106 changes: 69 additions & 37 deletions nmma/em/svdmodel_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,29 @@
)


def main():
def get_parser():

parser = argparse.ArgumentParser(
description="Surrogate model performance benchmark"
)
parser.add_argument(
"--model", type=str, required=True, help="Name of the SVD model created"
"--model",
type=str,
help="Name of the SVD model created",
required=True,
)
parser.add_argument(
"--svd-path",
type=str,
help="Path to the SVD directory, \
with {model}_mag.pkl, {model}_lbol.pkl or {model_tf.pkl}",
required=True,
)
parser.add_argument(
"--data-path",
type=str,
help="Path to the directory of light curve files",
required=True,
)
parser.add_argument(
"--data-file-type",
Expand All @@ -46,7 +51,7 @@ def main():
parser.add_argument(
"--interpolation-type",
type=str,
required=True,
default="tensorflow",
help="Type of interpolation performed",
)
parser.add_argument(
Expand Down Expand Up @@ -81,17 +86,21 @@ def main():
)
parser.add_argument(
"--filters",
nargs="+",
type=str,
help="A comma seperated list of filters to use (e.g. g,r,i). If none is provided, will use all the filters available",
help="A space-separated list of filters to use (e.g. g r i). If none is provided, will use all the filters available",
)
parser.add_argument(
"--ncpus",
type=int,
default=4,
help="Number of CPU to be used (default: 4)",
default=1,
help="Number of CPU to be used (default: 1)",
)
parser.add_argument(
"--outdir", type=str, default="output", help="Path to the output directory"
"--outdir",
type=str,
default="benchmark_output",
help="Path to the output directory",
)
parser.add_argument(
"--ignore-bolometric",
Expand All @@ -105,60 +114,77 @@ def main():
default=False,
help="only look for local svdmodels (ignore Zenodo)",
)
args = parser.parse_args()

return parser


def create_benchmark(
model,
svd_path,
data_path,
data_file_type="bulla",
interpolation_type="tensorflow",
data_time_unit="days",
svd_ncoeff=10,
tmin=0.0,
tmax=14.0,
dt=0.1,
filters=None,
ncpus=1,
outdir="benchmark_output",
ignore_bolometric=True,
local_only=False,
):

# make the outdir
if not os.path.isdir(args.outdir):
os.makedirs(args.outdir)
if not os.path.isdir(outdir):
os.makedirs(outdir)

# get the grid data file path
file_extensions = ["dat", "csv", "dat.gz", "h5"]
filenames = []
for file_extension in file_extensions:
if not args.ignore_bolometric:
filenames = filenames + glob.glob(f"{args.data_path}/*.{file_extension}")
if not ignore_bolometric:
filenames = filenames + glob.glob(f"{data_path}/*.{file_extension}")
else:
filenames = filenames + glob.glob(
f"{args.data_path}/*[!_Lbol].{file_extension}"
)
filenames = filenames + glob.glob(f"{data_path}/*[!_Lbol].{file_extension}")
if len(filenames) == 0:
raise ValueError("Need at least one file to interpolate.")

# read the grid data
grid_data = read_photometry_files(filenames, datatype=args.data_file_type)
grid_data = read_photometry_files(filenames, datatype=data_file_type)

# create the SVD training data
MODEL_FUNCTIONS = {
k: v for k, v in model_parameters.__dict__.items() if inspect.isfunction(v)
}
if args.model not in list(MODEL_FUNCTIONS.keys()):
raise ValueError(
f"{args.model} unknown. Please add to nmma.em.model_parameters"
)
model_function = MODEL_FUNCTIONS[args.model]
if model not in list(MODEL_FUNCTIONS.keys()):
raise ValueError(f"{model} unknown. Please add to nmma.em.model_parameters")
model_function = MODEL_FUNCTIONS[model]
grid_training_data, parameters = model_function(grid_data)

# create the SVDlight curve model
sample_times = np.arange(args.tmin, args.tmax + args.dt, args.dt)
sample_times = np.arange(tmin, tmax + dt, dt)
light_curve_model = SVDLightCurveModel(
args.model,
model,
sample_times,
svd_path=args.svd_path,
mag_ncoeff=args.svd_ncoeff,
interpolation_type=args.interpolation_type,
local_only=args.local_only,
svd_path=svd_path,
mag_ncoeff=svd_ncoeff,
interpolation_type=interpolation_type,
filters=filters,
local_only=local_only,
)

# get the filts
if not args.filters:
if not filters:
first_entry_name = list(grid_training_data.keys())[0]
first_entry = grid_training_data[first_entry_name]
filts = first_entry.keys() - set(["t"] + parameters)
filts = list(filts)
else:
filts = args.filters
filts = filters

print(f"Benchmarking model {args.model} on filter {filts} with {args.ncpus} cpus")
print(f"Benchmarking model {model} on filter {filts} with {ncpus} cpus")

def chi2_func(grid_entry_name, data_time_unit="days"):
# fetch the grid data and parameter
Expand All @@ -180,10 +206,10 @@ def chi2_func(grid_entry_name, data_time_unit="days"):
)
grid_t = grid_t / time_scale_factor

used_grid_t = grid_t[(grid_t > args.tmin) * (grid_t < args.tmax)]
used_grid_t = grid_t[(grid_t > tmin) * (grid_t < tmax)]
grid_mAB = {}
for filt in filts:
time_indices = (grid_t > args.tmin) * (grid_t < args.tmax)
time_indices = (grid_t > tmin) * (grid_t < tmax)
grid_mAB_per_filt_array = np.array(grid_entry[filt])
grid_mAB[filt] = grid_mAB_per_filt_array[time_indices]
# fetch the grid parameters
Expand All @@ -200,17 +226,17 @@ def chi2_func(grid_entry_name, data_time_unit="days"):
return chi2

grid_entry_names = list(grid_training_data.keys())
if args.ncpus == 1:
if ncpus == 1:
chi2_dict_array = [
chi2_func(grid_entry_name, data_time_unit=args.data_time_unit)
chi2_func(grid_entry_name, data_time_unit=data_time_unit)
for grid_entry_name in grid_entry_names
]
else:
chi2_dict_array = p_map(
chi2_func,
grid_entry_names,
data_time_unit=args.data_time_unit,
num_cpus=args.ncpus,
data_time_unit=data_time_unit,
num_cpus=ncpus,
)

chi2_array_by_filt = {}
Expand All @@ -224,5 +250,11 @@ def chi2_func(grid_entry_name, data_time_unit="days"):
plt.ylabel("Count")
plt.hist(chi2_array_by_filt[filt], label=filt, bins=51, histtype="step")
plt.legend()
plt.savefig(f"{args.outdir}/{filt}.pdf", bbox_inches="tight")
plt.savefig(f"{outdir}/{filt}.pdf", bbox_inches="tight")
plt.close()


def main():
parser = get_parser()
args = parser.parse_args()
create_benchmark(**vars(args))
18 changes: 17 additions & 1 deletion nmma/tests/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import glob
import numpy as np

from ..em import training, model_parameters, io
from ..em import training, model_parameters, io, svdmodel_benchmark


def test_training():
Expand Down Expand Up @@ -44,6 +44,14 @@ def test_training():
interpolation_type=interpolation_type,
)

svdmodel_benchmark.create_benchmark(
model_name,
ModelPath,
dataDir,
interpolation_type=interpolation_type,
filters=filts,
)

interpolation_type = "tensorflow"
training.SVDTrainingModel(
model_name,
Expand All @@ -55,3 +63,11 @@ def test_training():
svd_path=ModelPath,
interpolation_type=interpolation_type,
)

svdmodel_benchmark.create_benchmark(
model_name,
ModelPath,
dataDir,
interpolation_type=interpolation_type,
filters=filts,
)
Loading