Skip to content

Commit

Permalink
Added Gaussian prior and add a new write-to-file function in utils.
Browse files Browse the repository at this point in the history
  • Loading branch information
haukekoehn committed Nov 18, 2024
1 parent a49e0db commit f86de4c
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 2 deletions.
2 changes: 0 additions & 2 deletions src/fiesta/inference/fiesta.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.utils.PRNG_keys import initialize_rng_keys

import time # TODO: remove me!

default_hyperparameters = {
"seed": 0,
"n_chains": 20,
Expand Down
49 changes: 49 additions & 0 deletions src/fiesta/inference/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,55 @@ def log_prob(self, x: dict[str, Array]) -> Float:
jnp.zeros_like(variable),
)
return output + jnp.log(1.0 / (self.xmax - self.xmin))


@jaxtyped(typechecker=typechecker)
class Normal(Prior):
mu: float = 0.0
sigma: float = 1.0

def __repr__(self):
return f"Normal(mu={self.mu}, sigma={self.sigma})"

def __init__(
self,
mu: Float,
sigma: Float,
naming: list[str],
transforms: dict[str, tuple[str, Callable]] = {},
**kwargs,
):
super().__init__(naming, transforms)
assert self.n_dim == 1, "Normal needs to be 1D distributions"
self.mu = mu
self.sigma = sigma

def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
"""
Sample from a normal distribution.
Parameters
----------
rng_key : PRNGKeyArray
A random key to use for sampling.
n_samples : int
The number of samples to draw.
Returns
-------
samples : dict
Samples from the distribution. The keys are the names of the parameters.
"""
samples = jax.random.normal(rng_key, (n_samples,),)
samples = self.mu + self.sigma * samples
return self.add_name(samples[None])

def log_prob(self, x: dict[str, Array]) -> Float:
variable = x[self.naming[0]]
return -1/(2*self.sigma**2) * (variable-self.mu)**2 - jnp.sqrt(2*jnp.pi*self.sigma**2)

# class DiracDelta(Prior):

Expand Down
13 changes: 13 additions & 0 deletions src/fiesta/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,19 @@ def load_event_data(filename):

return data

def write_event_data(filename: str, data: dict):
"""
Takes a magnitude dict and writes it to filename.
The magnitude dict should have filters as keys, the arrays should have the structure [[mjd, mag, err]].
"""
with open(filename, "w") as out:
for filt in data.keys():
for data_point in data[filt]:
time = Time(data_point[0], format = "mjd")
filt_name = filt.replace("_", ":")
line = f"{time.isot} {filt_name} {data_point[1]:f} {data_point[2]:f}"
out.write(line +"\n")

#########################
### Filters ###
#########################
Expand Down

0 comments on commit f86de4c

Please sign in to comment.