Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error in Permutation when using make_maf #41

Open
Jice-Zeng opened this issue Aug 10, 2024 · 1 comment
Open

Error in Permutation when using make_maf #41

Jice-Zeng opened this issue Aug 10, 2024 · 1 comment

Comments

@Jice-Zeng
Copy link

Hi Simon,
When I run the following code:
# Reproduction snippet for the make_maf / Permutation failure.
# NOTE(review): assumes jnp = jax.numpy, jr = jax.random, and that make_maf,
# NLE, prior_fn and simulator_fn (sbijax, per the traceback) were imported or
# defined in an earlier notebook cell.
n_dim_data = 2
n_layers, hidden_sizes = 5, (64, 64)
# Derive the layer count and per-layer dimensions from the variables above
# instead of hard-coding 5 and [2, 2, 2, 2, 2], so they cannot drift apart.
neural_network = make_maf(
    n_dim_data,
    n_layers=n_layers,
    n_layer_dimensions=[n_dim_data] * n_layers,
    hidden_sizes=hidden_sizes,
)
fns = prior_fn, simulator_fn
model = NLE(fns, neural_network)
obs = jnp.array([-1.0, 1.0])
data, _ = model.simulate_data(jr.PRNGKey(0), n_simulations=10_000)
params, losses = model.fit(jr.PRNGKey(1), data=data)
inference_results, diagnostics = model.sample_posterior(
    jr.PRNGKey(2), params, obs
)

It raises the following error:
TypeError Traceback (most recent call last)
Cell In[108], line 3
1 obs = jnp.array([-1.0, 1.0])
2 data, _ = model.simulate_data(jr.PRNGKey(0), n_simulations=10_000)
----> 3 params, losses = model.fit(
4 jr.PRNGKey(1), data=data
5 )
6 inference_results, diagnostics = model.sample_posterior(
7 jr.PRNGKey(2), params, obs
8 )

File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nle.py:87, in NLE.fit(self, rng_key, data, optimizer, n_iter, batch_size, percentage_data_as_validation_set, n_early_stopping_patience, **kwargs)
83 itr_key, rng_key = jr.split(rng_key)
84 train_iter, val_iter = self.as_iterators(
85 itr_key, data, batch_size, percentage_data_as_validation_set
86 )
---> 87 params, losses = self._fit_model_single_round(
88 seed=rng_key,
89 train_iter=train_iter,
90 val_iter=val_iter,
91 optimizer=optimizer,
92 n_iter=n_iter,
93 n_early_stopping_patience=n_early_stopping_patience,
94 )
96 return params, losses

File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nle.py:109, in NLE._fit_model_single_round(self, seed, train_iter, val_iter, optimizer, n_iter, n_early_stopping_patience)
99 def _fit_model_single_round(
100 self,
101 seed,
(...)
106 n_early_stopping_patience,
107 ):
108 init_key, seed = jr.split(seed)
--> 109 params = self._init_params(init_key, **next(iter(train_iter)))
110 state = optimizer.init(params)
112 @jax.jit
113 def step(params, state, **batch):

File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nle.py:176, in NLE._init_params(self, rng_key, **init_data)
175 def _init_params(self, rng_key, **init_data):
--> 176 params = self.model.init(
177 rng_key, method="log_prob", y=init_data["y"], x=init_data["theta"]
178 )
179 return params

File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/haiku/_src/transform.py:166, in without_state..init_fn(*args, **kwargs)
165 def init_fn(*args, **kwargs) -> hk.MutableParams:
--> 166 params, state = f.init(*args, **kwargs)
167 if state:
168 raise base.NonEmptyStateError(
169 "If your transformed function uses hk.{get,set}_state then use "
170 "hk.transform_with_state.")

File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/haiku/_src/transform.py:422, in transform_with_state..init_fn(rng, *args, **kwargs)
420 with base.new_context(rng=rng) as ctx:
421 try:
--> 422 f(*args, **kwargs)
423 except jax.errors.UnexpectedTracerError as e:
424 raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e

File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nn/make_flow.py:137, in _make_maf.._flow(method, **kwargs)
132 raise ValueError(
133 f"n_dimension at layer {i} is layer than the dimension of"
134 f" the following layer {i + 1}"
135 )
136 layers.append(layer)
--> 137 layers.append(Permutation(order, 1))
138 chain = Chain(layers[:-1])
140 base_distribution = distrax.Independent(
141 distrax.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)),
142 1,
143 )

File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/distrax/_src/utils/jittable.py:32, in Jittable.new(failed resolving arguments)
30 except ValueError:
31 registered_cls = cls # Already registered.
---> 32 return object.new(registered_cls)

TypeError: Can't instantiate abstract class Permutation without an implementation for abstract method 'forward_and_log_det'.

I suspect the error comes from the Permutation class in the surjectors library. It does not implement 'forward_and_log_det', which distrax.Bijector declares as an abstract method. It only defines '_forward_and_likelihood_contribution' and '_inverse_and_likelihood_contribution'. So I added two methods to Permutation, 'forward_and_log_det' and 'inverse_and_log_det', as shown below:
class Permutation(distrax.Bijector):
    """Permute the dimensions of a vector.

    The permutation is applied along the last axis. Reordering elements is
    volume-preserving, so the log-determinant of the Jacobian is always zero
    and the bijector is declared to have a constant Jacobian.

    Args:
        permutation: a vector of integer indexes representing the order of
            the elements
        event_ndims_in: number of input event dimensions

    Examples:
        >>> from surjectors import Permutation
        >>> from jax import numpy as jnp
        >>>
        >>> order = jnp.arange(10)
        >>> perm = Permutation(order, 1)
    """

    def __init__(self, permutation, event_ndims_in: int):
        # A permutation is a linear (hence affine) map whose Jacobian is a
        # permutation matrix, so distrax's constant-Jacobian flag applies.
        super().__init__(event_ndims_in, is_constant_jacobian=True)
        self.permutation = permutation

    def forward_and_log_det(self, x):
        """Permute ``x`` and return ``(y, log_det)`` with a zero log-det.

        Implements the abstract ``distrax.Bijector.forward_and_log_det``;
        without it the class cannot be instantiated.
        """
        return self._forward_and_likelihood_contribution(x)

    def inverse_and_log_det(self, y):
        """Undo the permutation on ``y`` and return ``(x, log_det)``."""
        return self._inverse_and_likelihood_contribution(y)

    def _forward_and_likelihood_contribution(self, z):
        return z[..., self.permutation], self._zero_log_det(z)

    def _inverse_and_likelihood_contribution(self, y):
        # argsort of a permutation is its inverse: inv[perm[i]] == i.
        permutation_inv = jnp.argsort(self.permutation)
        return y[..., permutation_inv], self._zero_log_det(y)

    def _zero_log_det(self, arr):
        # One zero per batch element: drop the trailing event dimensions
        # (for event_ndims_in == 1 this is shape[:-1], as before).
        batch_ndims = jnp.ndim(arr) - self.event_ndims_in
        return jnp.full(jnp.shape(arr)[:batch_ndims], 0.0)

With this patched Permutation class, make_maf now works. Could you please check whether my implementation is correct?
Thanks!

@dirmeier
Copy link
Owner

dirmeier commented Aug 15, 2024

Thanks for reporting. I'll fix this as soon as time allows. As I said in the other thread, we did a major refactor in which I likely introduced bugs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants