Skip to content

Commit

Permalink
Add option to repeat simulations
Browse files Browse the repository at this point in the history
  • Loading branch information
mcoughlin committed Aug 25, 2023
1 parent bd64f62 commit 49a4863
Showing 1 changed file with 25 additions and 6 deletions.
31 changes: 25 additions & 6 deletions nmma/eos/create_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pandas as pd

import numpy as np
import scipy.interpolate

import lalsimulation as lalsim
from gwpy.table import Table
Expand Down Expand Up @@ -267,7 +266,13 @@ def get_parser():
action="store_true",
help="Whether to only sample prior parameters in injection file",
)

parser.add(
"-r",
"--repeated-simulations",
default=0,
type=int,
help="Number of repeated simulations, fixing other parameters (default: 0)",
)
return parser


Expand All @@ -292,10 +297,10 @@ def main(args=None):

if not args.original_parameters:
# load the EOS
radius_val, mass_val, Lambda_val= np.loadtxt(
radius_val, mass_val, Lambda_val = np.loadtxt(
args.eos_file, usecols=[0, 1, 2], unpack=True
)

# load the injection json file
if args.injection_file:
if args.injection_file.endswith(".json"):
Expand Down Expand Up @@ -338,6 +343,20 @@ def main(args=None):
generation_seed=args.generation_seed,
)
dataframe_from_prior = injection_creator.get_injection_dataframe()
if args.repeated_simulations > 0:
repeats = []
timeshifts = []
injection_creator.n_injection = args.repeated_simulations
for index, row in dataframe_from_prior.iterrows():
timeshift_frame = injection_creator.get_injection_dataframe()
for ii in range(args.repeated_simulations):
timeshifts.append(timeshift_frame["KNtimeshift"][ii])
repeats.append(row)
dataframe_from_prior = pd.concat(repeats, axis=1).transpose().reset_index()
dataframe_from_prior.drop(
labels=["index", "KNtimeshift"], axis="columns", inplace=True
)
dataframe_from_prior["KNtimeshift"] = timeshifts

inj_columns = set(dataframe_from_inj.columns.tolist())
prior_columns = set(dataframe_from_prior.columns.tolist())
Expand Down Expand Up @@ -385,8 +404,8 @@ def main(args=None):

for injIdx in range(0, Ninj):
mMax, rMax, lam1, lam2, r1, r2, R_14, R_16 = EOS2Parameters(
mass_val,
radius_val,
mass_val,
radius_val,
Lambda_val,
dataframe["mass_1_source"][injIdx],
dataframe["mass_2_source"][injIdx],
Expand Down

0 comments on commit 49a4863

Please sign in to comment.