From Pyro's website (Pyro is built on PyTorch):
We’re excited to announce the release of NumPyro, a NumPy-backed Pyro using JAX for automatic differentiation and JIT compilation, with over 100x speedup for HMC and NUTS!
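The announcement attributes the speedup to JAX's automatic differentiation and JIT compilation. As a rough illustration of that mechanism (not NumPyro's actual code), the sketch below writes one HMC trajectory, a leapfrog integrator, in plain JAX. `jax.grad` derives the gradient of the log density, and `jax.jit` together with `jax.lax.fori_loop` lets XLA compile the whole loop at once instead of dispatching each small gradient evaluation from Python. Since HMC and NUTS spend most of their time on many tiny gradient computations, removing that per-step overhead is a plausible source of large gains. The standard-normal target, step size and step count are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


# Toy target (an assumption for illustration): unnormalized log density of a
# standard normal. NumPyro would derive this from the user's model instead.
def log_density(q):
    return -0.5 * jnp.sum(q ** 2)


# jax.grad builds the gradient function via automatic differentiation.
grad_log_density = jax.grad(log_density)


def hamiltonian(q, p):
    # Potential energy (negative log density) plus kinetic energy.
    return -log_density(q) + 0.5 * jnp.sum(p ** 2)


def leapfrog(q, p, step_size, num_steps):
    """Simulate Hamiltonian dynamics with the leapfrog integrator."""
    p = p + 0.5 * step_size * grad_log_density(q)  # initial half step for momentum

    def body(i, state):
        q, p = state
        q = q + step_size * p                      # full step for position
        p = p + step_size * grad_log_density(q)    # full step for momentum
        return q, p

    # num_steps - 1 full steps run inside a loop that XLA can compile.
    q, p = jax.lax.fori_loop(0, num_steps - 1, body, (q, p))
    q = q + step_size * p                          # last full position step
    p = p + 0.5 * step_size * grad_log_density(q)  # final half step for momentum
    return q, p


# Compile the whole trajectory; num_steps is static so the loop bound is fixed.
leapfrog_jit = jax.jit(leapfrog, static_argnums=(3,))

# Usage: one trajectory of 20 steps from an arbitrary starting state.
q0 = jnp.array([1.0, -0.5])
p0 = jnp.array([0.3, 0.2])
q1, p1 = leapfrog_jit(q0, p0, 0.1, 20)
```

The first call to `leapfrog_jit` pays the compilation cost, and later calls with the same `num_steps` and input shapes reuse the compiled code, which may be one reason measured speedups vary with model size and run length.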
My questions:
- Where exactly does NumPyro's performance gain over Pyro (sometimes 340x, sometimes only 2x) come from?
- More importantly, why (or rather, where) would I continue to use Pyro?
Extra:
- How should I compare the performance and features of NumPyro with TensorFlow Probability when deciding which to use where?
That's a good question. I asked the same thing on Pyro's dedicated forum, and here is the answer from one of its core developers: "There are many cool stuffs in Pyro that do not appear in NumPyro, for example, see Contributed code section in Pyro docs. For me, while developing, it is much easier to debug PyTorch code than Jax code (though Jax team has put much effort to help debugging in recent releases). Hence to implement a new inference algorithm, it is easier for me to work in Pyro."