You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/haiku/_src/transform.py:166, in without_state..init_fn(*args, **kwargs)
165 def init_fn(*args, **kwargs) -> hk.MutableParams:
--> 166 params, state = f.init(*args, **kwargs)
167 if state:
168 raise base.NonEmptyStateError(
169 "If your transformed function uses hk.{get,set}_state then use "
170 "hk.transform_with_state.")
File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/haiku/_src/transform.py:422, in transform_with_state..init_fn(rng, *args, **kwargs)
420 with base.new_context(rng=rng) as ctx:
421 try:
--> 422 f(*args, **kwargs)
423 except jax.errors.UnexpectedTracerError as e:
424 raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e
File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nn/make_flow.py:137, in _make_maf.._flow(method, **kwargs)
132 raise ValueError(
133 f"n_dimension at layer {i} is layer than the dimension of"
134 f" the following layer {i + 1}"
135 )
136 layers.append(layer)
--> 137 layers.append(Permutation(order, 1))
138 chain = Chain(layers[:-1])
140 base_distribution = distrax.Independent(
141 distrax.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)),
142 1,
143 )
TypeError: Can't instantiate abstract class Permutation without an implementation for abstract method 'forward_and_log_det'.
I suspect the error comes from the `Permutation` class in the surjectors library. When I checked `Permutation`, it does not implement the method `forward_and_log_det`; it only defines `_forward_and_likelihood_contribution` and `_inverse_and_likelihood_contribution`. So I added two methods, `forward_and_log_det` and `inverse_and_log_det`, to `Permutation` as shown below:
class Permutation(distrax.Bijector):
    """Permute the dimensions of a vector.

    The permutation is applied along the last axis of the input. A
    permutation only reorders coordinates, so the log-absolute-determinant
    of its Jacobian is zero in both directions.

    Args:
        permutation: a vector of integer indexes representing the order of
            the elements
        event_ndims_in: number of input event dimensions

    Raises:
        ValueError: if ``permutation`` is not one-dimensional.

    Examples:
        >>> from surjectors import Permutation
        >>> from jax import numpy as jnp
        >>>
        >>> order = jnp.arange(10)
        >>> perm = Permutation(order, 1)
    """

    def __init__(self, permutation, event_ndims_in: int):
        # The Jacobian of a permutation is a constant permutation matrix
        # (|det J| == 1 everywhere); declaring it lets distrax treat the
        # log-determinant as input-independent.
        super().__init__(event_ndims_in, is_constant_jacobian=True)
        # Coerce so that lists/tuples are accepted too; the inverse path
        # needs an array, and indexing with a raw Python list would fail
        # on `.size` in the original implementation.
        self.permutation = jnp.asarray(permutation)
        if self.permutation.ndim != 1:
            raise ValueError(
                "permutation must be a one-dimensional vector of indexes"
            )

    def forward_and_log_det(self, x):
        """Permute ``x`` and return the (zero) forward log-determinant.

        This is the abstract method ``distrax.Bijector`` requires; it
        delegates to the surjectors-style helper so both APIs agree.
        """
        return self._forward_and_likelihood_contribution(x)

    def inverse_and_log_det(self, y):
        """Undo the permutation of ``y`` and return the (zero) log-determinant.

        Overridden so that inverse-direction calls (e.g. density evaluation
        through a transformed distribution) are supported.
        """
        return self._inverse_and_likelihood_contribution(y)

    def _forward_and_likelihood_contribution(self, z):
        """Reorder the last axis of ``z`` by ``self.permutation``."""
        return z[..., self.permutation], self._zero_log_det(z)

    def _inverse_and_likelihood_contribution(self, y):
        """Reorder the last axis of ``y`` by the inverse permutation."""
        # argsort of a permutation is its inverse: inv[perm[i]] == i.
        permutation_inv = jnp.argsort(self.permutation)
        return y[..., permutation_inv], self._zero_log_det(y)

    @staticmethod
    def _zero_log_det(arr):
        """Return zeros with one entry per leading (batch) element of ``arr``."""
        # NOTE(review): drops exactly one trailing axis, i.e. assumes the
        # permuted axis is the only event axis (event_ndims_in == 1).
        return jnp.full(jnp.shape(arr)[:-1], 0.0)
With the modified `Permutation` class, `make_maf` now works. Could you please check whether my implementation is correct?
Thanks!
The text was updated successfully, but these errors were encountered:
Thanks for reporting. I'll fix this as soon as time allows. As I mentioned in the other thread, we did a major refactor in which I likely introduced bugs.
Hi Simon,
When I ran the following code:
# Flow configuration: 2-D data, five MAF layers that each keep the full
# data dimension, and a two-hidden-layer conditioner network.
n_dim_data = 2
n_layers = 5
hidden_sizes = (64, 64)
neural_network = make_maf(
    n_dim_data,
    n_layers=n_layers,
    n_layer_dimensions=[n_dim_data] * n_layers,
    hidden_sizes=hidden_sizes,
)

# Neural likelihood estimation from the prior and simulator functions.
fns = (prior_fn, simulator_fn)
model = NLE(fns, neural_network)

# Simulate, fit, and then sample the posterior for a single observation.
obs = jnp.array([-1.0, 1.0])
data, _ = model.simulate_data(jr.PRNGKey(0), n_simulations=10_000)
params, losses = model.fit(jr.PRNGKey(1), data=data)
inference_results, diagnostics = model.sample_posterior(
    jr.PRNGKey(2), params, obs
)
I got the following error:
TypeError Traceback (most recent call last)
Cell In[108], line 3
1 obs = jnp.array([-1.0, 1.0])
2 data, _ = model.simulate_data(jr.PRNGKey(0), n_simulations=10_000)
----> 3 params, losses = model.fit(
4 jr.PRNGKey(1), data=data
5 )
6 inference_results, diagnostics = model.sample_posterior(
7 jr.PRNGKey(2), params, obs
8 )
File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nle.py:87, in NLE.fit(self, rng_key, data, optimizer, n_iter, batch_size, percentage_data_as_validation_set, n_early_stopping_patience, **kwargs)
83 itr_key, rng_key = jr.split(rng_key)
84 train_iter, val_iter = self.as_iterators(
85 itr_key, data, batch_size, percentage_data_as_validation_set
86 )
---> 87 params, losses = self._fit_model_single_round(
88 seed=rng_key,
89 train_iter=train_iter,
90 val_iter=val_iter,
91 optimizer=optimizer,
92 n_iter=n_iter,
93 n_early_stopping_patience=n_early_stopping_patience,
94 )
96 return params, losses
File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nle.py:109, in NLE._fit_model_single_round(self, seed, train_iter, val_iter, optimizer, n_iter, n_early_stopping_patience)
99 def _fit_model_single_round(
100 self,
101 seed,
(...)
106 n_early_stopping_patience,
107 ):
108 init_key, seed = jr.split(seed)
--> 109 params = self._init_params(init_key, **next(iter(train_iter)))
110 state = optimizer.init(params)
112 @jax.jit
113 def step(params, state, **batch):
File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nle.py:176, in NLE._init_params(self, rng_key, **init_data)
175 def _init_params(self, rng_key, **init_data):
--> 176 params = self.model.init(
177 rng_key, method="log_prob", y=init_data["y"], x=init_data["theta"]
178 )
179 return params
File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/haiku/_src/transform.py:166, in without_state..init_fn(*args, **kwargs)
165 def init_fn(*args, **kwargs) -> hk.MutableParams:
--> 166 params, state = f.init(*args, **kwargs)
167 if state:
168 raise base.NonEmptyStateError(
169 "If your transformed function uses hk.{get,set}_state then use "
170 "hk.transform_with_state.")
File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/haiku/_src/transform.py:422, in transform_with_state..init_fn(rng, *args, **kwargs)
420 with base.new_context(rng=rng) as ctx:
421 try:
--> 422 f(*args, **kwargs)
423 except jax.errors.UnexpectedTracerError as e:
424 raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e
File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nn/make_flow.py:137, in _make_maf.._flow(method, **kwargs)
132 raise ValueError(
133 f"n_dimension at layer {i} is layer than the dimension of"
134 f" the following layer {i + 1}"
135 )
136 layers.append(layer)
--> 137 layers.append(Permutation(order, 1))
138 chain = Chain(layers[:-1])
140 base_distribution = distrax.Independent(
141 distrax.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)),
142 1,
143 )
File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/distrax/_src/utils/jittable.py:32, in Jittable.__new__(failed resolving arguments)
30 except ValueError:
31 registered_cls = cls # Already registered.
---> 32 return object.__new__(registered_cls)
TypeError: Can't instantiate abstract class Permutation without an implementation for abstract method 'forward_and_log_det'.
I suspect the error comes from the `Permutation` class in the surjectors library. When I checked `Permutation`, it does not implement the method `forward_and_log_det`; it only defines `_forward_and_likelihood_contribution` and `_inverse_and_likelihood_contribution`. So I added two methods, `forward_and_log_det` and `inverse_and_log_det`, to `Permutation` as shown below:
class Permutation(distrax.Bijector):
"""Permute the dimensions of a vector.
With the modified `Permutation` class, `make_maf` now works. Could you please check whether my implementation is correct?
Thanks!
The text was updated successfully, but these errors were encountered: