From f86de4cdedfaebe4ede80c90fbfc32b287c90c91 Mon Sep 17 00:00:00 2001 From: haukekoehn Date: Mon, 18 Nov 2024 12:08:46 +0100 Subject: [PATCH] Added Gaussian prior and add a new write-to-file function in utils. --- src/fiesta/inference/fiesta.py | 2 -- src/fiesta/inference/prior.py | 49 ++++++++++++++++++++++++++++++++++ src/fiesta/utils.py | 13 +++++++++ 3 files changed, 62 insertions(+), 2 deletions(-) diff --git a/src/fiesta/inference/fiesta.py b/src/fiesta/inference/fiesta.py index 078dc81..443bd18 100644 --- a/src/fiesta/inference/fiesta.py +++ b/src/fiesta/inference/fiesta.py @@ -18,8 +18,6 @@ from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline from flowMC.utils.PRNG_keys import initialize_rng_keys -import time # TODO: remove me! - default_hyperparameters = { "seed": 0, "n_chains": 20, diff --git a/src/fiesta/inference/prior.py b/src/fiesta/inference/prior.py index 220c49c..d6da5c8 100644 --- a/src/fiesta/inference/prior.py +++ b/src/fiesta/inference/prior.py @@ -145,6 +145,55 @@ def log_prob(self, x: dict[str, Array]) -> Float: jnp.zeros_like(variable), ) return output + jnp.log(1.0 / (self.xmax - self.xmin)) + + +@jaxtyped(typechecker=typechecker) +class Normal(Prior): + mu: float = 0.0 + sigma: float = 1.0 + + def __repr__(self): + return f"Normal(mu={self.mu}, sigma={self.sigma})" + + def __init__( + self, + mu: Float, + sigma: Float, + naming: list[str], + transforms: dict[str, tuple[str, Callable]] = {}, + **kwargs, + ): + super().__init__(naming, transforms) + assert self.n_dim == 1, "Normal needs to be 1D distributions" + self.mu = mu + self.sigma = sigma + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + """ + Sample from a normal distribution. + + Parameters + ---------- + rng_key : PRNGKeyArray + A random key to use for sampling. + n_samples : int + The number of samples to draw. + + Returns + ------- + samples : dict + Samples from the distribution. The keys are the names of the parameters. + + """ + samples = jax.random.normal(rng_key, (n_samples,),) + samples = self.mu + self.sigma * samples + return self.add_name(samples[None]) + + def log_prob(self, x: dict[str, Array]) -> Float: + variable = x[self.naming[0]] + return -1/(2*self.sigma**2) * (variable-self.mu)**2 - jnp.sqrt(2*jnp.pi*self.sigma**2) # class DiracDelta(Prior): diff --git a/src/fiesta/utils.py b/src/fiesta/utils.py index 8cafad3..cdba2ff 100644 --- a/src/fiesta/utils.py +++ b/src/fiesta/utils.py @@ -222,6 +222,19 @@ def load_event_data(filename): return data +def write_event_data(filename: str, data: dict): + """ + Takes a magnitude dict and writes it to filename. + The magnitude dict should have filters as keys, the arrays should have the structure [[mjd, mag, err]]. + """ + with open(filename, "w") as out: + for filt in data.keys(): + for data_point in data[filt]: + time = Time(data_point[0], format = "mjd") + filt_name = filt.replace("_", ":") + line = f"{time.isot} {filt_name} {data_point[1]:f} {data_point[2]:f}" + out.write(line +"\n") + ######################### ### Filters ### #########################