-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added 1) Uniform sampling test and 2) basic SBC test #64
base: master
Are you sure you want to change the base?
added 1) Uniform sampling test and 2) basic SBC test #64
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a quick word to say that I haven't forgotten about this :)
while any(is_chain_short): | ||
num_samples *= 2 | ||
print(f"Running sampler for {num_samples} more iterations") | ||
posterior = sampler.run(num_samples=num_samples, compile=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can do posterior += sampler.run()
if you want to add new samples starting from the last state.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah yeah that would be simpler!
|
||
def effective_sample_size(chain): | ||
num_samples = len(chain) | ||
tau_chain = max(integrated_time(chain, tol=100, quiet=True), 1, ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be added to the codebase soon, once the core refactor is merged.
return ess, int(tau_chain) | ||
|
||
def rank_statistic(post_samples, prior_sample): | ||
list_bools = [sam < prior_sample for sam in post_samples] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these arrays? In this case post_samples < prior_samples
should give the result directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point!
rank_statistic_dict = {name: [] for name in prior_names} | ||
|
||
starttime = time.time() | ||
for i in range(num_replicas): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's batch this! Do you know exactly what prevents it from being batched?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As each iteration build the model and the sampler (and some of the objects such as the distributions have user defined classes) I imagined that compiling all this in JAX wouldn't work (though I haven't tried it!).
Also, I think you've mentioned before that the building of the likelihood and sampler can be simplified a lot, which means it'll be faster and that we'll be able to compile/batch it. I'm not sure how to do this though
f8f3e6b
to
965f6dd
Compare
Now that we're moving samplers to blackjax, I think tests based on SBC would be better there. It will also be a lot simpler to implement than in MCX. |
Here are some changes to get the ball rolling with MCMC tests. There's still a lot left to do though.
I include a notebook to look at the empirical distributions of the test outputs. Though obviously that will need to be automated at some point.
Though it might be fun at some point to have a similar notebook where a user (they would probably be an mcmc-nerd..) could interact with the outputs of the tests themselves. That way they could convince themselves that the sampler works correctly.
Uniform sampling
This is similar to the
hmc_test.py
except that it gets a lot more samples. I set a simple model with 1 parameter where the prior is uniform and the data doesn't depend on the prior. That way the posterior is also uniform. I use HMC to get 500K samples and in the notebook I plot the empirical CDF along with the line f(y)=x.This is a very simple test to do and has saved me from a few bugs in the past!
Perhaps this could be removed and the
hmc_test.py
could simply run for longer (to make sure that the variance is correct).SBC
This is a first implementation of SBC. It will work for any model with scalar parameters. Problems/things to think about:
integrated_time
function to get the IAT and used that to determine whether the chain should be run for longer. For some generated datasets the sampler needs to be run for longer. It might be helpful to generate less data. That way the IAT will (probably) always be below some desired threshold. The test would then still check the IAT, but could perhaps discard that sample