Any insight into the dev model for oryx? #38

Open
femtomc opened this issue Feb 13, 2023 · 4 comments

@femtomc
Contributor

femtomc commented Feb 13, 2023

Hi all!

Will Oryx continue to be actively maintained? Are there maintainers who are hoping to continue working on the package?

@sharadmv
Collaborator

Hey! Oryx is being maintained to work against JAX at HEAD, but I'm not working on any new features (I work full time on JAX and jax-triton now).

@femtomc
Contributor Author

femtomc commented Feb 13, 2023

@sharadmv Thanks for the info!

Just to be transparent about motivations, I'm building a system based on Gen which uses JAX.

From Oryx's core, I'd like to use the inverse and ildj transforms to support an "exact logpdf" language whose model objects act like distributions. (A bit more about Gen: there is no single modeling language. There's a collection of languages for models, whose objects implement an abstract interface. Distributions are one such object, and I'm considering a language for "distributions + ILDJ-compatible functions" as another implementor of the interface.)

The implementation of these transformations in Oryx seems well designed for this task -- so I'd like to use Oryx (or, at the very least, the conceptual content of Oryx) for the task.
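For readers unfamiliar with the idea, here is a minimal sketch of what "exact logpdf via inverse and ILDJ" means. It is written in plain Python with a hand-written inverse and Jacobian term so that it runs without dependencies; it does not call Oryx, and the function names (`forward`, `inverse`, `ildj`, `transformed_logpdf`) are illustrative assumptions. The point of Oryx's transforms, as discussed in this thread, is to derive the inverse and ILDJ pieces from the forward function instead of writing them by hand.

```python
import math


def normal_logpdf(x, loc=0.0, scale=1.0):
    """Log-density of a univariate normal distribution."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def forward(x):
    """Transformation applied to the base sample: y = exp(x)."""
    return math.exp(x)


def inverse(y):
    """Hand-written inverse of `forward`: x = log(y)."""
    return math.log(y)


def ildj(y):
    """Inverse log-det Jacobian: log |d inverse(y) / dy| = -log(y)."""
    return -math.log(y)


def transformed_logpdf(y):
    """Exact logpdf of y = forward(x) with x ~ Normal(0, 1)."""
    return normal_logpdf(inverse(y)) + ildj(y)


def lognormal_logpdf(y):
    """Closed-form standard log-normal density, used only as a cross-check."""
    return -math.log(y) - 0.5 * math.log(2.0 * math.pi) - 0.5 * math.log(y) ** 2
```

The change-of-variables result `transformed_logpdf(y)` agrees with the closed-form log-normal density for any `y > 0`, which is the property an "exact logpdf" language relies on.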

I've also been considering the maintenance/dev model of Oryx -- as I was considering Oryx as a dependency. Depending on information about the maintenance of Oryx:

  • I was considering forking the transformation code for inverse and ildj. I've used the propagate interpreter for another language previously, so I already have a modified version of that in my codebase, but I wanted to hold off on any forking/mangling until I could converse with you. There are parts of the library which I would likely make great use of, and other parts which I don't think I would use (e.g. I think I wouldn't use any of the inference modules, but the nn and optimizer modules might be useful, if I'm understanding their value prop - see below).
  • You would have a better understanding than me about this -- do you feel like the current state of Oryx is "ruleset complete" for ILDJ/inverses?

A few other things I've been thinking about in Oryx proper:

  • I'd love to chat about the nn module -- and the intent behind it. One fascinating value proposition would be supporting nn parametrized functions which operate on random variables -- which are also compat with ILDJ. I'm sort of guessing that's what was intended. In Gen's Julia implementation, we constructed an "invertible transformation" distributions DSL - but it's less expressive than Oryx - and I don't think we seriously considered neural networks + ILDJ.

Cool work! I will admit, when I first looked at Oryx - I totally misjudged the conceptual content - only recently did I really appreciate the language value proposition and design.

Thanks for any comments.

@sharadmv
Collaborator

> Just to be transparent about motivations, I'm building a system based on Gen which uses JAX.

Sounds awesome! Do you have a repo/doc I could read to learn more?

> but potentially the nn and optimizer module might be useful, if I'm understanding their value prop - see below

The idea of parameterized, invertible functions is pretty core to Oryx (I have some examples internally of using Oryx nn to implement RealNVP and MAF). However, the nn library is a bit more opinionated than it needs to be. I'd recommend using harvest directly to build your own mini state-management library.

Here's an example one:

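# Mini state library

import dataclasses
import functools
from typing import Any

import jax
import oryx

# reap collects values tagged as state variables; plant injects values for them.
collect = functools.partial(oryx.core.reap, tag=oryx.core.state.VARIABLE)
inject = functools.partial(oryx.core.plant, tag=oryx.core.state.VARIABLE)

@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Module:
  params: Any
  apply: Any

  def __call__(self, *args, **kwargs):
    return self.apply(self.params, *args, **kwargs)

  def tree_flatten(self):
    # params are the pytree children; apply is static auxiliary data.
    return (self.params,), (self.apply,)

  @classmethod
  def tree_unflatten(cls, data, xs):
    # data is the auxiliary data (apply,); xs are the children (params,).
    return cls(xs[0], data[0])

@oryx.core.ppl.log_prob.register(Module)
def module_log_prob(module):
  return oryx.core.ppl.log_prob(lambda *args: module.apply(module.params, *args))

def init(f):
  def wrapped(state_key, *args):
    # Run f once to collect its variables, then close over state_key so the
    # resulting apply takes (params, *args).
    params = collect(f)(state_key, *args)
    return Module(params, inject(functools.partial(f, state_key)))
  return wrapped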

The Module's apply could be an invertible function (like a normalizing flow) or an Oryx probabilistic program that we can use log_prob with.

> You would have a better understanding than me about this -- do you feel like the current state of Oryx is "ruleset complete" for ILDJ/inverses?

The rules are probably not as complete as I'd like them to be, namely for lack of time/demand. I'm happy to accept new rules in a PR though!

Control flow is a big hole in the rules right now -- inverting something like scan is possible, but highly nontrivial. However, doing so would enable things like HMMs expressed via scan! That was a long-term goal w/ Oryx for me but I have never gotten around to implementing it.
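As a small illustration of why inverting a scan is possible in principle, here is a hand-worked sketch for one very simple case: a Gaussian autoregressive chain written as a left-to-right loop. This is an assumption-laden toy in plain Python, not an Oryx rule and not a full HMM (there are no hidden states); the function names and the parameters `a`, `s`, and `y0` are hypothetical. Because each step is invertible given the previous output, the inverse is another scan, and the Jacobian is lower triangular with `s` on the diagonal.

```python
import math


def normal_logpdf(x, loc=0.0, scale=1.0):
    """Log-density of a univariate normal distribution."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def forward_scan(noise, a, s, y0=0.0):
    """y_t = a * y_{t-1} + s * x_t, carried left to right like a scan."""
    ys, carry = [], y0
    for x in noise:
        carry = a * carry + s * x
        ys.append(carry)
    return ys


def inverse_scan(ys, a, s, y0=0.0):
    """Recover the noise from the outputs: x_t = (y_t - a * y_{t-1}) / s."""
    xs, prev = [], y0
    for y in ys:
        xs.append((y - a * prev) / s)
        prev = y
    return xs


def trajectory_logpdf(ys, a, s, y0=0.0):
    """Logpdf of ys via inverse + ILDJ, with x_t ~ Normal(0, 1) and s > 0."""
    xs = inverse_scan(ys, a, s, y0)
    ildj = -len(ys) * math.log(s)  # dy/dx is lower triangular, diagonal = s
    return sum(normal_logpdf(x) for x in xs) + ildj


def transition_logpdf(ys, a, s, y0=0.0):
    """Cross-check: sum of transition densities Normal(y_t; a * y_{t-1}, s)."""
    total, prev = 0.0, y0
    for y in ys:
        total += normal_logpdf(y, loc=a * prev, scale=s)
        prev = y
    return total
```

The two log-densities agree on any trajectory. The hard part for a general rule is doing this automatically for arbitrary scan bodies, which this sketch does not attempt.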

@femtomc
Contributor Author

femtomc commented Feb 17, 2023

> Sounds awesome! Do you have a repo/doc I could read to learn more?

@sharadmv I've sent you a private email about this (private, because we're still working closed source).

> The idea of parameterized, invertible functions is pretty core to Oryx (I have some examples internally of using Oryx nn to implement RealNVP and MAF). However, the nn library is a bit more opinionated than it needs to be. I'd recommend using harvest directly to build your own mini state-management library.

harvest is super neat. I wrote a restricted version of harvest previously -- I wasn't concerned with handling higher-order primitives (in my modeling code I'm still not, partly because there's a level of model design modularity which allows me to use higher-order models to support things like vmap or scan, etc).

> Control flow is a big hole in the rules right now -- inverting something like scan is possible, but highly nontrivial. However, doing so would enable things like HMMs expressed via scan! That was a long-term goal w/ Oryx for me but I have never gotten around to implementing it.

Right, this is pretty interesting. Because Gen doesn't assume any restrictions on the return value function $f$, there's a straightforward way to support things like logpdf for internal random choices in models which use scan (Gen doesn't assume that the return value function is a transformation whose output you wish to constrain). The way you gain access to scan is to use one of these higher-order models above (which also implement Gen's interface).

I am curious what happens if I use an Oryx model without control flow, which supports logpdf - and then shove it into one of the higher-order models above.

(re -- when I make comments about Gen + Oryx, I'm thinking of Oryx as providing a DSL for defining objects with sample and exact logpdf evaluation - but if an object supports these two interfaces, you can automatically define Gen's interface on it.)
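A rough sketch of that last point, in plain Python: given only a sampler and an exact logpdf, one can mechanically define simulate-style and generate-style operations. The names here (`Trace`, `ExactDensityModel`, `simulate`, `generate`) are hypothetical and only loosely modeled on Gen's interface as I understand it; this is not Gen's or Oryx's API, and it covers a single scalar choice only.

```python
import math
import random
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class Trace:
    value: float
    score: float  # log density of `value` under the model


@dataclass(frozen=True)
class ExactDensityModel:
    """Wraps any object exposing a sampler and an exact logpdf."""
    sample: Callable[[random.Random], float]
    logpdf: Callable[[float], float]

    def simulate(self, rng):
        """Sample a value and record its exact log density as the score."""
        value = self.sample(rng)
        return Trace(value, self.logpdf(value))

    def generate(self, rng, constraint: Optional[float] = None):
        """Return (trace, weight); the weight is the logpdf of the constraint."""
        if constraint is None:
            return self.simulate(rng), 0.0
        score = self.logpdf(constraint)
        return Trace(constraint, score), score


def lognormal_logpdf(y):
    """Closed-form standard log-normal log density."""
    return -math.log(y) - 0.5 * math.log(2.0 * math.pi) - 0.5 * math.log(y) ** 2


# Example model: y = exp(x) with x ~ Normal(0, 1).
model = ExactDensityModel(
    sample=lambda rng: math.exp(rng.gauss(0.0, 1.0)),
    logpdf=lognormal_logpdf,
)
```

With this wrapper, `model.generate(rng, 2.0)` returns a trace whose value is the constraint and a weight equal to its log density, while an unconstrained call falls back to `simulate` with weight zero.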
