Skip to content

Commit

Permalink
Merge pull request #2 from haukekoehn/grb_filters
Browse files Browse the repository at this point in the history
Pass filter to GRB model
  • Loading branch information
haukekoehn authored Oct 9, 2023
2 parents 8beb6c5 + 6558afd commit 82bce42
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 18 deletions.
13 changes: 5 additions & 8 deletions nmma/em/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,17 +281,16 @@ def __init__(
self.svd_lbol_model = None
elif self.interpolation_type == "tensorflow":
import tensorflow as tf
tf.get_logger().setLevel("ERROR")
from tensorflow.keras.models import load_model

# TODO: remove below 3 lines once <model>_tf.pkl files on Zenodo are updated to <model>.pkl
if not os.path.exists(modelfile):
warnings.warn(
f"Attempting to load {core_model_name}_tf.pkl. In the future, all model files will have the format <model>.pkl, regardless of --interpolation-type."
)
modelfile = os.path.join(self.svd_path, f"{core_model_name}_tf.pkl")

tf.get_logger().setLevel("ERROR")
from tensorflow.keras.models import load_model


if not local_only:
_, model_filters = get_model(
self.svd_path, f"{self.model}_tf", filters=filters
Expand Down Expand Up @@ -369,7 +368,7 @@ def observation_angle_conversion(self, parameters):
def generate_lightcurve(self, sample_times, parameters):
if self.parameter_conversion:
new_parameters = parameters.copy()
new_parameters, _ = self.parameter_conversion(new_parameters, [])
new_parameters, _ = self.parameter_conversion(new_parameters)
else:
new_parameters = parameters.copy()

Expand Down Expand Up @@ -481,10 +480,9 @@ def __repr__(self):
return self.__class__.__name__ + "(model={0})".format(self.model)

def generate_lightcurve(self, sample_times, parameters):

if self.parameter_conversion:
new_parameters = parameters.copy()
new_parameters, _ = self.parameter_conversion(new_parameters, [])
new_parameters, _ = self.parameter_conversion(new_parameters)
else:
new_parameters = parameters.copy()

Expand Down Expand Up @@ -571,7 +569,6 @@ def observation_angle_conversion(self, parameters):
return parameters

def generate_lightcurve(self, sample_times, parameters):

total_lbol = np.zeros(len(sample_times))
total_mag = {}

Expand Down
5 changes: 3 additions & 2 deletions nmma/em/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,10 +548,10 @@ def fluxDensity(t, nu, **params):
def grb_lc(t_day, Ebv, param_dict, filters=None):
day = 86400.0 # in seconds
tStart = (np.amin(t_day)) * day
tStart = max(10**(-5), tStart)
tStart = max(10**(-5)*day, tStart)
tEnd = (np.amax(t_day) + 1) * day
tnode = min(len(t_day), 201)
default_time = np.logspace(np.log10(tStart), np.log10(tEnd), base=10.0, num=tnode-1)
default_time = np.logspace(np.log10(tStart), np.log10(tEnd), base=10.0, num=tnode)
filts, lambdas = get_default_filts_lambdas(filters=filters)

nu_0s = scipy.constants.c / lambdas
Expand All @@ -571,6 +571,7 @@ def grb_lc(t_day, Ebv, param_dict, filters=None):
# output flux density is in milliJansky
try:
mJys = fluxDensity(times, nus, **param_dict)

except TimeoutError:
return t_day, np.zeros(t_day.shape), {}

Expand Down
18 changes: 10 additions & 8 deletions nmma/joint/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from ..em.model import SVDLightCurveModel, KilonovaGRBLightCurveModel
from ..em.model import SVDLightCurveModel, KilonovaGRBLightCurveModel, GRBLightCurveModel, GenericCombineLightCurveModel
from ..em.likelihood import OpticalLightCurve
from .conversion import MultimessengerConversion, MultimessengerConversionWithLambdas

Expand Down Expand Up @@ -124,7 +124,6 @@ def __init__(self, interferometers, waveform_generator,
time_marginalization=False, distance_marginalization=False,
phase_marginalization=False, distance_marginalization_lookup_table=None,
jitter_time=True, reference_frame="sky", time_reference="geocenter"):

# construct the eos prior
if with_eos:
xx = np.arange(0, Neos + 1)
Expand Down Expand Up @@ -166,20 +165,23 @@ def __init__(self, interferometers, waveform_generator,
GWLikelihood = ROQGravitationalWaveTransient(**gw_likelihood_kwargs)

# initialize the EM likelihood
if not filters:
filters = list(light_curve_data.keys())
sample_times = np.arange(tmin, tmax, 0.1)
light_curve_model_kwargs = dict(model=light_curve_model_name, sample_times=sample_times,
svd_path=light_curve_SVD_path,
parameter_conversion=parameter_conversion,
mag_ncoeff=mag_ncoeff, lbol_ncoeff=lbol_ncoeff,
interpolation_type=light_curve_interpolation_type)
interpolation_type=light_curve_interpolation_type, filters=filters)

if with_grb:
light_curve_model = KilonovaGRBLightCurveModel(sample_times=sample_times,
kilonova_kwargs=light_curve_model_kwargs,
GRB_resolution=grb_resolution)
models = []
models.append(SVDLightCurveModel(**light_curve_model_kwargs))
models.append(GRBLightCurveModel(sample_times = sample_times, resolution = grb_resolution, filters = filters, parameter_conversion = parameter_conversion))
light_curve_model = GenericCombineLightCurveModel(models = models, sample_times=sample_times)
else:
light_curve_model = SVDLightCurveModel(**light_curve_model_kwargs)
if not filters:
filters = list(light_curve_data.keys())

em_likelihood_kwargs = dict(light_curve_model=light_curve_model, filters=filters,
light_curve_data=light_curve_data, trigger_time=em_trigger_time,
error_budget=error_budget, tmin=tmin, tmax=tmax)
Expand Down

0 comments on commit 82bce42

Please sign in to comment.