PyTorch implementation of the implicit reparametrisation trick for mixture distributions, based on Figurnov et al., 2018, "Implicit Reparameterization Gradients", and on the implementation in TensorFlow Probability.
It can be used directly for variational inference with mixture distributions as the variational family.
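The core identity from Figurnov et al. is `dz/dφ = -(∂F(z; φ)/∂φ) / q(z; φ)`, where `F` is the CDF and `q` the density of the distribution being sampled. For a mixture, both are weighted sums over components, so the identity applies even though picking a component is a discrete step. Below is a minimal sketch of that idea, assuming a univariate mixture of Normal components whose parameters are 1-D tensors of shape `(K,)`. `implicit_mixture_rsample` is a hypothetical helper written for illustration. It is not this repository's API. It draws samples without gradients and then injects the implicit gradient through a stop-gradient surrogate.

```python
import torch
from torch.distributions import Categorical, Normal


def implicit_mixture_rsample(logits, loc, scale, n):
    """Reparametrised samples from a univariate Gaussian mixture (sketch).

    logits, loc, scale: 1-D tensors of shape (K,) for K components.
    Returns n samples whose gradients w.r.t. the parameters follow the
    implicit reparametrisation identity dz/dphi = -(dF/dphi) / q(z).
    """
    # 1. Ancestral sampling without gradients: pick a component, then sample it.
    with torch.no_grad():
        idx = Categorical(logits=logits).sample((n,))
        z = Normal(loc[idx], scale[idx]).sample()

    # 2. Mixture CDF F(z; phi) and log-density log q(z; phi) at the fixed samples.
    components = Normal(loc, scale)  # batch of K components
    z_col = z.unsqueeze(-1)          # shape (n, 1), broadcasts against (K,)
    weights = torch.softmax(logits, dim=-1)
    cdf = (weights * components.cdf(z_col)).sum(-1)
    log_q = torch.logsumexp(
        torch.log_softmax(logits, dim=-1) + components.log_prob(z_col), dim=-1
    )

    # 3. Surrogate: the forward value is exactly z, and the backward pass
    #    yields -dF/dphi / q(z), i.e. the implicit reparametrisation gradient.
    return z - (cdf - cdf.detach()) / log_q.exp().detach()
```

As a sanity check, `E[z] = Σ_k w_k μ_k`. After `z = implicit_mixture_rsample(logits, loc, scale, n)` and `z.mean().backward()`, `loc.grad` should therefore be close to `softmax(logits)` for large `n`. The density is computed with `logsumexp` so that far-from-mode samples do not underflow as easily.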
Remarks:
- For multivariate mixtures, the class currently supports only mixtures whose component distributions fully factorise across dimensions.
- Also added a `StableNormal` distribution, which overrides the default `cdf` method with a more numerically stable implementation taken from a comment on pytorch/pytorch#52973. It also provides a `_log_cdf` method, but this is not used for the implicit reparametrisation.
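The default `torch.distributions.Normal.cdf` evaluates `0.5 * (1 + erf(u / sqrt(2)))`. Far in the lower tail, `erf` rounds to `-1` and the CDF collapses to exactly `0`. That is a problem for the implicit gradient, which needs accurate CDF values wherever samples land. The sketch below illustrates one standard fix, the identity `0.5 * (1 + erf(x)) == 0.5 * erfc(-x)`, which avoids the cancelling subtraction. It is an assumption about the general approach and does not reproduce the code from the linked comment. `StableNormalSketch` and its `log_cdf` method are hypothetical names, not this repository's `StableNormal` or `_log_cdf`. The sketch assumes a PyTorch version that provides `torch.special.erfc` and `torch.special.log_ndtr`.

```python
import math

import torch
from torch.distributions import Normal


class StableNormalSketch(Normal):
    """Normal whose cdf avoids cancellation in the lower tail (sketch).

    Normal.cdf computes 0.5 * (1 + erf(u / sqrt(2))); for very negative u,
    erf rounds to -1 and the result becomes 0. Using the identity
    0.5 * (1 + erf(x)) == 0.5 * erfc(-x) keeps tiny tail probabilities.
    """

    def cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        u = (value - self.loc) / (self.scale * math.sqrt(2))
        return 0.5 * torch.special.erfc(-u)

    def log_cdf(self, value):
        # Hypothetical helper: log Phi((value - loc) / scale), computed stably.
        return torch.special.log_ndtr((value - self.loc) / self.scale)
```

Because `erfc` is differentiable in PyTorch, gradients of this `cdf` with respect to `loc` and `scale` still flow, which is what the implicit reparametrisation relies on.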