
added 1) Uniform sampling test and 2) basic SBC test #64

Open · jeremiecoullon wants to merge 2 commits into master

Conversation

jeremiecoullon (Contributor)

Here are some changes to get the ball rolling with MCMC tests. There's still a lot left to do, though.
I've included a notebook to look at the empirical distributions of the test outputs, though obviously that will need to be automated at some point.

It might also be fun at some point to have a similar notebook where a user (probably an MCMC nerd..) could interact with the test outputs themselves, to convince themselves that the sampler works correctly.

Uniform sampling

This is similar to hmc_test.py, except that it draws a lot more samples. I set up a simple model with one parameter where the prior is uniform and the data doesn't depend on the parameter, so the posterior is also uniform. I use HMC to draw 500K samples, and in the notebook I plot the empirical CDF along with the line y = x.
This is a very simple test to run, and it has saved me from a few bugs in the past!
Perhaps this could be removed, and hmc_test.py could simply run for longer (to make sure that the variance is correct).
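To make automating this concrete, the notebook's visual check could eventually be replaced by something like the sketch below (plain NumPy/SciPy, independent of MCX; a goodness-of-fit test instead of the ECDF plot):

import numpy as np
from scipy import stats

def check_uniform_posterior(samples, lower=0.0, upper=1.0, alpha=0.01):
    """Kolmogorov-Smirnov check that `samples` look Uniform(lower, upper).

    Note: the KS p-value assumes independent draws, so correlated MCMC
    output should be thinned (e.g. by the IAT) before calling this.
    """
    scaled = (np.asarray(samples) - lower) / (upper - lower)
    _, p_value = stats.kstest(scaled, "uniform")
    return p_value > alpha

# Sanity check with actual uniform draws:
rng = np.random.default_rng(0)
assert check_uniform_posterior(rng.uniform(size=500_000))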

SBC

This is a first implementation of SBC. It will work for any model with scalar parameters (a schematic sketch of the core loop follows the list below). Problems/things to think about:

  • It doesn't yet work for array-valued parameters.
  • It builds the likelihood, sampler, etc. at every iteration, so it can't be batched. As a result the Python loop is very slow (around 1h10min for 1K rank statistics).
  • The MCMC chain length is variable, as the test needs independent samples. I used emcee's integrated_time function to get the IAT and used that to decide whether the chain should be run for longer; for some generated datasets the sampler does need more iterations. It might be helpful to generate less data, so that the IAT will (probably) always be below some desired threshold. The test would then still check the IAT, but could perhaps discard that sample.
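For reference, here is a rough sketch of the core SBC loop with the model-specific pieces factored out into callables (sample_prior, simulate_data, and run_mcmc are placeholders, not existing MCX functions):

import numpy as np

def sbc_rank_statistics(sample_prior, simulate_data, run_mcmc,
                        num_replicas=1000, num_posterior_draws=100, seed=0):
    """SBC for a single scalar parameter.

    For each replica: draw theta from the prior, simulate a dataset, run the
    sampler, and record the rank of theta among the (roughly independent)
    posterior draws. If the sampler is correct, the ranks are uniform on
    {0, ..., num_posterior_draws}.
    """
    rng = np.random.default_rng(seed)
    ranks = np.empty(num_replicas, dtype=int)
    for i in range(num_replicas):
        theta = sample_prior(rng)
        data = simulate_data(rng, theta)
        posterior_draws = run_mcmc(rng, data, num_posterior_draws)
        ranks[i] = int(np.sum(posterior_draws < theta))
    return ranks

A flat histogram of the ranks then indicates the sampler is calibrated.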

rlouf (Owner) left a comment

Just a quick word to say that I haven't forgotten about this :)

while any(is_chain_short):
    num_samples *= 2
    print(f"Running sampler for {num_samples} more iterations")
    posterior = sampler.run(num_samples=num_samples, compile=True)
rlouf (Owner)

You can do posterior += sampler.run() if you want to add new samples starting from the last state.
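Applied to the snippet above, that would look roughly like this (same names as in the diff; untested against the current API):

while any(is_chain_short):
    num_samples *= 2
    print(f"Running sampler for {num_samples} more iterations")
    # Append new draws to the existing trace, starting from the last state.
    posterior += sampler.run(num_samples=num_samples, compile=True)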

jeremiecoullon (Contributor, Author)

Ah yeah that would be simpler!


def effective_sample_size(chain):
    num_samples = len(chain)
    tau_chain = max(integrated_time(chain, tol=100, quiet=True), 1)
rlouf (Owner)

This should be added to the codebase soon, once the core refactor is merged.
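For reference, a self-contained version of the helper might look like this (using emcee's integrated_time; a sketch, not necessarily what the codebase version will be):

from emcee.autocorr import integrated_time

def effective_sample_size(chain):
    """Return (ESS, integrated autocorrelation time) for a 1-D chain."""
    num_samples = len(chain)
    # integrated_time returns an array with one entry per parameter.
    tau_chain = max(float(integrated_time(chain, tol=100, quiet=True)[0]), 1.0)
    ess = int(num_samples / tau_chain)
    return ess, int(tau_chain)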

    return ess, int(tau_chain)

def rank_statistic(post_samples, prior_sample):
    list_bools = [sam < prior_sample for sam in post_samples]
rlouf (Owner)

Are these arrays? In that case post_samples < prior_sample should give the result directly.
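i.e. something like this (assuming post_samples is a 1-D NumPy array and prior_sample a scalar):

import numpy as np

def rank_statistic(post_samples, prior_sample):
    # Rank of the prior draw among the posterior draws.
    return int(np.sum(np.asarray(post_samples) < prior_sample))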

jeremiecoullon (Contributor, Author)

Good point!

rank_statistic_dict = {name: [] for name in prior_names}

starttime = time.time()
for i in range(num_replicas):
rlouf (Owner)

Let's batch this! Do you know exactly what prevents it from being batched?

jeremiecoullon (Contributor, Author)

As each iteration builds the model and the sampler (and some of the objects, such as the distributions, are user-defined classes), I imagined that compiling all of this in JAX wouldn't work (though I haven't tried it!).

Also, I think you've mentioned before that the building of the likelihood and sampler can be simplified a lot, which would make it faster and let us compile/batch it. I'm not sure how to do that though.
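To make that concrete, the kind of batching being discussed might eventually look like the sketch below, assuming the sampler can be built once as a pure function outside the loop (run_chain here is hypothetical, not an existing MCX function):

import jax
import jax.numpy as jnp

def sbc_ranks_batched(rng_key, prior_draws, datasets, run_chain, num_samples):
    """Batched SBC ranks, assuming run_chain(key, data, n) -> (n,) draws is pure."""
    keys = jax.random.split(rng_key, prior_draws.shape[0])
    # vmap over (key, dataset) pairs instead of the Python loop.
    draws = jax.vmap(lambda k, d: run_chain(k, d, num_samples))(keys, datasets)
    # Rank of each prior draw among its posterior draws.
    return jnp.sum(draws < prior_draws[:, None], axis=1)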

@rlouf rlouf force-pushed the master branch 3 times, most recently from f8f3e6b to 965f6dd Compare February 23, 2021 11:28
rlouf (Owner) commented Aug 2, 2021

Now that we're moving samplers to blackjax, I think tests based on SBC would be better there. It will also be a lot simpler to implement than in MCX.
