Unify Pyro and NumPyro code generation #24

Open
neerajprad opened this issue Oct 4, 2019 · 1 comment
Labels
discussion Requires broader discussion

Comments

@neerajprad
Member

Looking through the code, I think we could factor out much of what the Pyro and NumPyro backends have in common. NumPyro now has a plate handler, and the distribution broadcasting semantics should be the same.

One part that might differ slightly is the torch/numpy operations, but we only need a few of them and can work out a generic dispatch mechanism for those. Are there other parts that differ significantly between the two backends?
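To make the dispatch idea concrete, here is a minimal sketch, assuming the generated code refers to operations by dotted name. The `OPS` table, the `gen_call` helper, and the choice of `np.matmul`/`np.exp` on the NumPyro side (with `np` taken to be whatever numpy-like module the generated code imports) are all hypothetical and not taken from the existing code.

```python
# Hypothetical table: generic op name -> backend-specific function name.
# Only the handful of ops the code generator needs would be listed.
OPS = {
    "pyro": {"mv": "torch.mv", "exp": "torch.exp"},
    "numpyro": {"mv": "np.matmul", "exp": "np.exp"},
}


def gen_call(backend, op, *args):
    """Return source text for calling generic `op` under `backend`."""
    try:
        fn = OPS[backend][op]
    except KeyError:
        # Fail loudly when a backend has no mapping for an op.
        raise NotImplementedError(
            f"no {op!r} operation for backend {backend!r}"
        ) from None
    return f"{fn}({', '.join(args)})"


print(gen_call("pyro", "mv", "X", "b"))
print(gen_call("numpyro", "mv", "X", "b"))
```

With a table like this, the shared generator would only ever mention generic names such as `mv`, and adding an operation means adding one row per backend.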

@neerajprad neerajprad added the discussion Requires broader discussion label Oct 4, 2019
@null-a
Collaborator

null-a commented Oct 7, 2019

Indeed! The code generation bits for numpyro started out as a copy and paste of the pyro variant and haven't diverged very much. FWIW, one approach I have in mind is to add an extra stage to the "compilation" pipeline that takes a ModelDesc (a full description of the model, assembled from the formula, data, priors, etc.) and produces something like an AST describing, in a generic way, how the model ought to be implemented in terms of probabilistic primitives (sample, observe, etc.) and linear algebra operations. I think this would capture a lot of the overlap we currently have. The individual back ends would then only need to walk this AST, turning sample into e.g. pyro.sample, mat-vec-mul into e.g. torch.mv, etc.
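A rough sketch of what such an AST and a back-end walker could look like follows. Every name here (`Var`, `MatVecMul`, `Sample`, `Observe`, `Backend`, `emit`, `emit_model`) is made up for illustration, observe is rendered as a sample call with `obs=`, and the generated code is assumed to import the distributions module as `dist`. It only emits source text and does not reflect the current ModelDesc structure.

```python
from dataclasses import dataclass
from typing import Tuple


# Hypothetical backend-neutral nodes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class MatVecMul:
    matrix: object
    vector: object


@dataclass(frozen=True)
class Sample:
    site: str
    dist: str
    args: Tuple[object, ...]


@dataclass(frozen=True)
class Observe:
    site: str
    dist: str
    args: Tuple[object, ...]
    value: object


# What a back end has to supply: names for the primitives it implements.
@dataclass(frozen=True)
class Backend:
    sample: str
    mv: str


def emit(node, backend):
    """Walk one node and return backend-specific source text."""
    if isinstance(node, Var):
        return node.name
    if isinstance(node, (int, float)):
        return repr(node)
    if isinstance(node, MatVecMul):
        lhs, rhs = emit(node.matrix, backend), emit(node.vector, backend)
        return f"{backend.mv}({lhs}, {rhs})"
    if isinstance(node, Sample):
        args = ", ".join(emit(a, backend) for a in node.args)
        return f"{backend.sample}({node.site!r}, dist.{node.dist}({args}))"
    if isinstance(node, Observe):
        args = ", ".join(emit(a, backend) for a in node.args)
        obs = emit(node.value, backend)
        return f"{backend.sample}({node.site!r}, dist.{node.dist}({args}), obs={obs})"
    raise TypeError(f"unknown node: {node!r}")


def emit_model(stmts, backend):
    """Emit one assignment per (target, node) pair."""
    return "\n".join(f"{target} = {emit(node, backend)}" for target, node in stmts)


program = [
    ("b", Sample("b", "Normal", (Var("prior_loc"), Var("prior_scale")))),
    ("mu", MatVecMul(Var("X"), Var("b"))),
    ("y", Observe("y", "Normal", (Var("mu"), 1.0), Var("y_obs"))),
]

PYRO = Backend(sample="pyro.sample", mv="torch.mv")
NUMPYRO = Backend(sample="numpyro.sample", mv="np.matmul")

print(emit_model(program, PYRO))
print(emit_model(program, NUMPYRO))
```

The point of the split is that `program` is built once from the model description, while each back end contributes only a small table of names (or, in a fuller version, its own walker).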

On the inference side, the two back ends are now more similar than I suspect they once were (because of Pyro interface changes, I think), though differences remain. I think the differences between the two implementations of prior are fairly typical of the differences in general. The pyro variant collects samples by tracing model execution multiple times (not vectorized) and fetches everything it needs (to implement e.g. get_param) from the trace. The numpyro variant traces the model in a vectorized way, but this doesn't provide access to the return value of the model, so the model is re-run with substitute to get hold of the return values (IIRC).
