Any way to "vmap" or "plate" in pymc #7175
justindomke asked this question in Q&A
-
Hi, the issue is not the graph vectorization but how PyMC registers variables in a model context. That's the same reason you can't put PyMC distributions inside a Scan function. Instead, you need to first convert the PyMC model to a PyTensor graph, operate on that graph, and then convert it back to a PyMC model. For some time we have been adding functionality that lets you represent a PyMC model as a pure PyTensor graph. I explored some vectorization model transformations a while ago here: https://gist.github.com/ricardoV94/99c53fbb8b2e9a68e1b2c6c4d761eaf4
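To make the registration problem concrete, here is a toy sketch in plain Python. `ToyModel`, `build_one`, and `pure_one` are hypothetical stand-ins, not PyMC classes or APIs; the only behavior borrowed from PyMC is that a model context refuses to register two variables under the same name. Naively "mapping" a model-building function therefore collides on names, while batching a pure (registration-free) function and registering the result once does not, which loosely mirrors the convert, operate, convert-back workflow described above.

```python
class ToyModel:
    """Hypothetical stand-in for a model context that registers named variables."""

    def __init__(self):
        self.named_vars = {}

    def register(self, name, value):
        # Like a PyMC model, refuse duplicate variable names
        if name in self.named_vars:
            raise ValueError(f"Variable name {name!r} already exists.")
        self.named_vars[name] = value
        return value


def build_one(model, mu):
    # Per-item builder that registers a variable called "x" as a side effect
    return model.register("x", mu + 1.0)


# Naive "vmap": calling the builder once per item re-registers "x"
model = ToyModel()
try:
    [build_one(model, mu) for mu in [0.0, 1.0, 2.0]]
except ValueError as err:
    print(err)


def pure_one(mu):
    # Same per-item logic, but with no registration side effect
    return mu + 1.0


# Batch the pure function first, then register the batched result once
model = ToyModel()
batched = model.register("x", [pure_one(mu) for mu in [0.0, 1.0, 2.0]])
print(batched)  # [1.0, 2.0, 3.0]
```

In real PyMC the "pure" side would be a PyTensor graph rather than a Python function, and the batching step would be a graph transformation rather than a list comprehension; the gist linked above explores that direction.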
-
Hello all,
I was wondering: does PyMC support any construct similar to `jax.vmap`? Consider this simple model, in which I've used vectorized distributions:
I was hoping for some kind of function like `pymc.vmap` that would allow me to do something like this instead:

Of course, for this simple example the original vectorized syntax is much clearer. However, when things get complicated, manual vectorization can get quite tricky, so I was hoping for something like this.
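As a hypothetical illustration (not the poster's model), the sketch below contrasts the two styles in NumPy for an assumed setup of J groups, each with its own mean `theta[j]` and N normal observations. `vectorized_logp` is the hand-vectorized version, where you must track axes yourself (`theta[:, None]`); `naive_vmap` writes the model for a single group and maps it over the leading axis. `naive_vmap` only emulates the semantics of `jax.vmap` with a Python loop; JAX itself traces the function once and produces a batched computation rather than looping.

```python
import numpy as np

LOG_2PI = np.log(2 * np.pi)


def group_logp(theta_j, y_j, sigma=1.0):
    """Log-density of one group's observations: y_j ~ Normal(theta_j, sigma)."""
    resid = (y_j - theta_j) / sigma
    return np.sum(-0.5 * resid**2 - np.log(sigma) - 0.5 * LOG_2PI)


def naive_vmap(f):
    """Map f over the leading axis of every argument and stack the results.

    Loop-based emulation of vmap semantics only, not how JAX implements it.
    """
    def batched(*args):
        return np.stack([f(*xs) for xs in zip(*args)])
    return batched


def vectorized_logp(theta, y, sigma=1.0):
    """Hand-vectorized version: broadcast theta (J,) against y (J, N)."""
    resid = (y - theta[:, None]) / sigma
    return np.sum(-0.5 * resid**2 - np.log(sigma) - 0.5 * LOG_2PI, axis=1)


rng = np.random.default_rng(0)
theta = rng.normal(size=3)                     # J = 3 group means
y = theta[:, None] + rng.normal(size=(3, 5))   # N = 5 observations per group

looped = naive_vmap(group_logp)(theta, y)
manual = vectorized_logp(theta, y)
print(np.allclose(looped, manual))  # True
```

The per-group function never mentions the batch axis, which is the convenience being asked for; the cost of the loop-based emulation is speed, which is why a real vmap batches the computation instead.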
I've noticed that PyTensor and Aesara now have `scan`, which I guess is somewhat more general than vmap, though I haven't been able to figure out whether `scan` can be used directly when defining PyMC models, or whether it's meant more for internal use.
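On the "more general than vmap" point, a toy pure-Python sketch may help. `toy_scan` and `toy_vmap` are hypothetical helpers whose semantics loosely follow a scan with a carried state; their signatures are not those of `pytensor.scan`. A map is the special case of a scan whose step ignores the carry, so scan can express anything vmap can, but it also allows dependence between steps and is inherently sequential. (Per the reply above, PyMC named distributions still can't be created inside a Scan step, whatever the loop construct.)

```python
from typing import Callable, List, Tuple, TypeVar

C = TypeVar("C")  # carry type
X = TypeVar("X")  # per-step input type
Y = TypeVar("Y")  # per-step output type


def toy_scan(step: Callable[[C, X], Tuple[C, Y]], init: C,
             xs: List[X]) -> Tuple[C, List[Y]]:
    """Sequential loop that threads a carry through every step."""
    carry, ys = init, []
    for x in xs:
        carry, y = step(carry, x)
        ys.append(y)
    return carry, ys


def toy_vmap(f: Callable[[X], Y], xs: List[X]) -> List[Y]:
    """A map expressed as a scan whose step passes the carry through unchanged."""
    _, ys = toy_scan(lambda carry, x: (carry, f(x)), None, xs)
    return ys


# Scan can express dependence between steps, e.g. a running sum...
total, running = toy_scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3, 4])
print(total, running)  # 10 [1, 3, 6, 10]

# ...whereas a map treats every element independently.
print(toy_vmap(lambda x: x * x, [1, 2, 3, 4]))  # [1, 4, 9, 16]
```

Because a map has no carried state, it can be executed as one batched operation, whereas a general scan has to run step by step.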