Looking through the code, I think it may be possible to abstract out much of the commonality between the Pyro and NumPyro backends. NumPyro now has a `plate` handler, and the distribution broadcasting semantics should be the same.
One part that might differ slightly is the torch/numpy operations, but we only need a few of them, and we can come up with a generic dispatch mechanism for those. Are there other parts that differ significantly between the two backends?
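One possible dispatch mechanism for those few array operations is sketched below. It assumes Python's standard `functools.singledispatch`, keyed on the type of the first argument; the op name `mv` is hypothetical. NumPy is registered unconditionally and PyTorch only if it is installed. NumPyro actually operates on JAX arrays, so a real version would also need to register JAX's array type.

```python
from functools import singledispatch

import numpy as np


@singledispatch
def mv(mat, vec):
    """Matrix-vector product; one implementation is registered per array type."""
    raise NotImplementedError(f"mv has no implementation for {type(mat).__name__}")


@mv.register(np.ndarray)
def _mv_numpy(mat, vec):
    # NumPy arrays: matmul of a 2-D array with a 1-D array gives a 1-D result.
    return np.matmul(mat, vec)


try:
    import torch
except ImportError:  # torch is optional in this sketch
    torch = None

if torch is not None:
    @mv.register(torch.Tensor)
    def _mv_torch(mat, vec):
        # PyTorch tensors: dedicated matrix-vector product.
        return torch.mv(mat, vec)


m = np.array([[1.0, 2.0], [3.0, 4.0]])
v = np.array([1.0, 1.0])
print(mv(m, v))  # [3. 7.]
```

Because `singledispatch` dispatches on the first argument only, each op needs its own set of registrations. With only a handful of ops, that seems manageable.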
Indeed! The code-generation bits for NumPyro started out as a copy-and-paste of the Pyro variant and haven't diverged very much. FWIW, one approach I have in mind is to add an extra stage to the "compilation" pipeline. That stage would take a `ModelDesc` (a full description of the model, assembled from the formula, data, priors, etc.) and produce something like an AST that describes, in a generic way, how the model ought to be implemented in terms of probabilistic primitives (sample, observe, etc.) and linear algebra operations. I think this would capture much of the overlap we currently have. The individual back ends would then only need to walk this AST, turning sample into e.g. `pyro.sample`, mat-vec-mul into e.g. `torch.mv`, and so on.
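A minimal sketch of what such a backend-neutral AST and the per-backend walk might look like follows. The node types (`Var`, `MatVecMul`, `Normal`, `Sample`, `Observe`), the `BACKENDS` table and the `emit_*` functions are all hypothetical and do not exist in the codebase. Each walk emits Python source text, in keeping with the code-generation approach; the generated lines assume `dist` refers to `pyro.distributions` or `numpyro.distributions`, and `jnp` to `jax.numpy`.

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical backend-neutral AST nodes (illustrative only).
@dataclass
class Var:
    name: str


@dataclass
class MatVecMul:
    mat: "Expr"
    vec: "Expr"


@dataclass
class Normal:
    loc: "Expr"
    scale: "Expr"


@dataclass
class Sample:
    name: str
    dist: Normal


@dataclass
class Observe:
    name: str
    dist: Normal
    obs: "Expr"


Expr = Union[Var, MatVecMul, Normal, float]

# Hypothetical per-backend spelling of each primitive.
BACKENDS = {
    "pyro": {"sample": "pyro.sample", "normal": "dist.Normal", "mv": "torch.mv"},
    "numpyro": {"sample": "numpyro.sample", "normal": "dist.Normal", "mv": "jnp.matmul"},
}


def emit_expr(e, ops):
    """Render an expression node as backend-specific source text."""
    if isinstance(e, Var):
        return e.name
    if isinstance(e, MatVecMul):
        return f"{ops['mv']}({emit_expr(e.mat, ops)}, {emit_expr(e.vec, ops)})"
    if isinstance(e, Normal):
        return f"{ops['normal']}({emit_expr(e.loc, ops)}, {emit_expr(e.scale, ops)})"
    if isinstance(e, (int, float)):
        return repr(e)
    raise TypeError(f"unknown expression node: {e!r}")


def emit_stmt(s, ops):
    """Render one probabilistic statement (sample or observe) as a line of code."""
    if isinstance(s, Sample):
        return f"{s.name} = {ops['sample']}({s.name!r}, {emit_expr(s.dist, ops)})"
    if isinstance(s, Observe):
        return f"{ops['sample']}({s.name!r}, {emit_expr(s.dist, ops)}, obs={emit_expr(s.obs, ops)})"
    raise TypeError(f"unknown statement node: {s!r}")


def emit_model(stmts, backend):
    """Walk the shared AST once, using the chosen backend's primitives."""
    ops = BACKENDS[backend]
    return "\n".join(emit_stmt(s, ops) for s in stmts)


# The same AST rendered for both back ends.
ir = [
    Sample("b", Normal(0.0, 1.0)),
    Observe("y", Normal(MatVecMul(Var("X"), Var("b")), 1.0), Var("y_obs")),
]
print(emit_model(ir, "pyro"))
# b = pyro.sample('b', dist.Normal(0.0, 1.0))
# pyro.sample('y', dist.Normal(torch.mv(X, b), 1.0), obs=y_obs)
```

With this split, building the AST from a `ModelDesc` would be shared, and a backend would mainly contribute its table of primitives (or its own walker, where the mapping is not one-to-one).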
On the inference side, our two back ends are now more similar than I suspect they once were (because of Pyro interface changes, I think), though differences remain. I think the differences between the two implementations of `prior` are fairly typical of the differences in general. The Pyro variant collects samples by tracing the model's execution multiple times (not vectorized) and fetches everything it needs (to implement e.g. `get_param`) from the traces. The NumPyro variant traces the model in a vectorized way, but this doesn't give access to the model's return value, so the model is re-run with `substitute` to get hold of the return values (IIRC).
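The two control flows for `prior` can be mimicked in a toy, dependency-free sketch. Everything here is a hypothetical stand-in, not a Pyro or NumPyro API: `run_traced` plays the role of a trace, its `fixed` argument plays the role of `substitute`, and the first pass of the NumPyro-style version is a plain loop where the real backend would draw all samples in one vectorized call.

```python
import random


def model(sample):
    # Hypothetical toy model: two latent sites plus a derived quantity
    # that is only available through the model's return value.
    b = sample("b", lambda rng: rng.gauss(0.0, 1.0))
    sigma = sample("sigma", lambda rng: abs(rng.gauss(0.0, 1.0)))
    return {"mu": 2.0 * b, "sigma": sigma}


def run_traced(model, rng, fixed=None):
    """Run the model once, recording every sample site.

    Sites named in `fixed` take the given value instead of a fresh draw,
    standing in for a `substitute`-style handler."""
    fixed = fixed or {}
    sites = {}

    def sample(name, draw):
        sites[name] = fixed[name] if name in fixed else draw(rng)
        return sites[name]

    return sites, model(sample)


def prior_by_repeated_tracing(model, n, seed=0):
    # Pyro-style: n separate traced runs; site values and return values
    # both come straight from each run.
    rng = random.Random(seed)
    return [run_traced(model, rng) for _ in range(n)]


def prior_by_sites_then_substitute(model, n, seed=0):
    # NumPyro-style: a first pass that keeps only site values (vectorized in
    # the real backend, a plain loop here), then a re-run of the model with
    # those values substituted to recover the return values.
    rng = random.Random(seed)
    site_draws = [run_traced(model, rng)[0] for _ in range(n)]
    return [run_traced(model, rng, fixed=s) for s in site_draws]


# Substituting the recorded site values reproduces the return values.
print(prior_by_repeated_tracing(model, 3) == prior_by_sites_then_substitute(model, 3))  # True
```

The second approach does more work per sample (one extra model run), which is the price paid for drawing the site values in a vectorized pass that does not expose return values.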