Skip to content

Commit

Permalink
refactored systematics code
Browse files Browse the repository at this point in the history
  • Loading branch information
sahiljhawar committed Nov 2, 2024
1 parent 05dc895 commit 082d277
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 53 deletions.
1 change: 1 addition & 0 deletions nmma/em/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from . import systematics


def from_list(self, systematics):
"""
Similar to `from_file` but instead of file buffer, takes a list of Prior strings
Expand Down
113 changes: 62 additions & 51 deletions nmma/em/systematics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import yaml
from pathlib import Path
import inspect
import warnings
from pathlib import Path

import yaml
from bilby.core.prior import analytical

warnings.simplefilter("module", DeprecationWarning)

Expand Down Expand Up @@ -39,26 +42,38 @@ def __init__(self, key, message):
"ztfr",
]

ALLOWED_DISTRIBUTIONS = dict(inspect.getmembers(analytical, inspect.isclass))


def get_positional_args(cls):
init_method = cls.__init__

signature = inspect.signature(init_method)
params = [
param.name
for param in signature.parameters.values()
if param.name != "self" and param.default == inspect.Parameter.empty
]

return params


DISTRIBUTION_PARAMETERS = {k: get_positional_args(v) for k, v in ALLOWED_DISTRIBUTIONS.items()}


def load_yaml(file_path):
return yaml.safe_load(Path(file_path).read_text())


def validate_only_one_true(yaml_dict):
for key, values in yaml_dict["config"].items():
if "value" not in values or type(values["value"]) is not bool:
raise ValidationError(
key, "'value' key must be present and be a boolean"
)
if "value" not in values or not isinstance(values["value"], bool):
raise ValidationError(key, "'value' key must be present and be a boolean")
true_count = sum(value["value"] for value in yaml_dict["config"].values())
if true_count > 1:
raise ValidationError(
"config", "Only one configuration key can be set to True at a time"
)
raise ValidationError("config", "Only one configuration key can be set to True at a time")
elif true_count == 0:
raise ValidationError(
"config", "At least one configuration key must be set to True"
)
raise ValidationError("config", "At least one configuration key must be set to True")


def validate_filters(filter_groups):
Expand Down Expand Up @@ -100,67 +115,63 @@ def validate_filters(filter_groups):


def validate_distribution(distribution):
if distribution != "Uniform":
dist_type = distribution.get("type")
if dist_type not in ALLOWED_DISTRIBUTIONS:
raise ValidationError(
"type",
f"Invalid distribution '{distribution}'. Only 'Uniform' distribution is supported",
"distribution type",
f"Invalid distribution '{dist_type}'. Allowed values are {', '.join([str(f) for f in ALLOWED_DISTRIBUTIONS])}",
)

required_params = DISTRIBUTION_PARAMETERS[dist_type]

def validate_fields(key, values, required_fields):
missing_fields = [
field for field in required_fields if values.get(field) is None
]
if missing_fields:
missing_params = set(required_params) - set(distribution.keys())
if missing_params:
raise ValidationError(
key, f"Missing fields: {', '.join(missing_fields)}"
"distribution", f"Missing required parameters for {dist_type} distribution: {', '.join(missing_params)}"
)
for field, expected_type in required_fields.items():
if not isinstance(values[field], expected_type):
raise ValidationError(
key, f"'{field}' must be of type {expected_type}"
)


def handle_withTime(key, values):
required_fields = {
"type": str,
"min": (float, int),
"max": (float, int),
"time_nodes": int,
"filters": list,
}
def create_prior_string(name, distribution):
dist_type = distribution.pop("type")
_ = distribution.pop("value")
_ = distribution.pop("time_nodes", None)
_ = distribution.pop("filters", None)
prior_class = ALLOWED_DISTRIBUTIONS[dist_type]
required_params = DISTRIBUTION_PARAMETERS[dist_type]
params = distribution.copy()

extra_params = set(params.keys()) - set(required_params)
if extra_params:
warnings.warn(f"Distribution parameters {extra_params} are not used by {dist_type} distribution and will be ignored")

params = {k: params[k] for k in required_params if k in params}

return f"{name} = {repr(prior_class(**params, name=name))}"

validate_fields(key, values, required_fields)

def handle_withTime(values):
validate_distribution(values)
filter_groups = values.get("filters", [])
validate_filters(filter_groups)
distribution = values.get("type")
validate_distribution(distribution)
result = []
time_nodes = values["time_nodes"]

for filter_group in filter_groups:
if isinstance(filter_group, list):
filter_name = "___".join(filter_group)
else:
filter_name = filter_group if filter_group is not None else "all"

for n in range(1, values["time_nodes"] + 1):
result.append(
f'sys_err_{filter_name}{n} = {values["type"]}(minimum={values["min"]},maximum={values["max"]},name="sys_err_{filter_name}{n}")'
)
for n in range(1, time_nodes + 1):
prior_name = f"sys_err_{filter_name}{n}"
result.append(create_prior_string(prior_name, values.copy()))

return result


def handle_withoutTime(key, values):
required_fields = {"type": str, "min": (float, int), "max": (float, int)}
validate_fields(key, values, required_fields)
distribution = values.get("type")
validate_distribution(distribution)
return [
f'sys_err = {values["type"]}(minimum={values["min"]},maximum={values["max"]},name="sys_err")'
]
def handle_withoutTime(values):
validate_distribution(values)
return [create_prior_string("sys_err", values)]


config_handlers = {
Expand All @@ -175,5 +186,5 @@ def main(yaml_file_path):
results = []
for key, values in yaml_dict["config"].items():
if values["value"] and key in config_handlers:
results.extend(config_handlers[key](key, values))
return results
results.extend(config_handlers[key](values))
return results
4 changes: 2 additions & 2 deletions nmma/em/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,8 @@ def calc_lc(
y_pred, sigma2_pred = gp.predict(
np.atleast_2d(param_list_postprocess), return_std=True
)
cAproj[i] = y_pred
cAstd[i] = sigma2_pred
cAproj[i] = np.squeeze(y_pred)
cAstd[i] = np.squeeze(sigma2_pred)

# coverrors = np.dot(VA[:, :n_coeff], np.dot(np.power(np.diag(cAstd[:n_coeff]), 2), VA[:, :n_coeff].T))
# errors = np.diag(coverrors)
Expand Down

0 comments on commit 082d277

Please sign in to comment.