Any insight into the dev model for oryx? #38
Hey! Oryx is being maintained to work against JAX at HEAD, but I'm not working on any new features (I work full time on JAX and jax-triton now).
@sharadmv Thanks for the info! Just to be transparent about my motivations: I'm building a system, based on Gen, that uses JAX. From Oryx's core, I'd like to use the
The implementation of these transformations in Oryx seems well designed for this task, so I'd like to use Oryx (or, at the very least, its conceptual content) for it. I've also been thinking about Oryx's maintenance/dev model, since I was considering Oryx as a dependency. Depending on information about the maintenance of Oryx:
A few other things I've been thinking about in Oryx proper:
Cool work! I will admit that when I first looked at Oryx, I totally misjudged the conceptual content; only recently did I really appreciate the language's value proposition and design. Thanks for any comments.
Sounds awesome! Do you have a repo/doc I could read to learn more?
The idea of parameterized, invertible functions is pretty core to Oryx (I have some examples internally of using Oryx
Here's an example:

```python
import dataclasses
import functools
from typing import Any

import jax
import oryx
from oryx.core import ppl

# Mini state library: reap/plant values tagged as state variables
collect = functools.partial(oryx.core.reap, tag=oryx.core.state.VARIABLE)
inject = functools.partial(oryx.core.plant, tag=oryx.core.state.VARIABLE)


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Module:
    params: Any
    apply: Any

    def __call__(self, *args, **kwargs):
        return self.apply(self.params, *args, **kwargs)

    def tree_flatten(self):
        # params are the pytree children; apply is static auxiliary data
        return (self.params,), (self.apply,)

    @classmethod
    def tree_unflatten(cls, data, xs):
        return cls(xs[0], data[0])


# Let ppl.log_prob dispatch on Module instances
@oryx.core.ppl.log_prob.register(Module)
def module_log_prob(module):
    return ppl.log_prob(lambda *args: module.apply(module.params, *args))


def init(f):
    def wrapped(state_key, *args):
        # Run f once and reap the values it tags as variables
        params = collect(f)(state_key, *args)
        # At call time, plant params back into f in place of the initial values
        return Module(params, inject(functools.partial(f, state_key)))
    return wrapped
```
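To make the `collect`/`inject` pattern concrete without depending on Oryx internals, here is a pure-Python toy of the reap/plant idea. `toy_sow`, `toy_reap`, `toy_plant`, and `model` are hypothetical helpers written only for illustration. They mimic the calling convention used in the snippet (`reap(f)(*args)` returns a dict of tagged values, `plant(f)(values, *args)` runs `f` with those values substituted), but Oryx itself implements this by tracing JAX programs, not with a context variable.

```python
import contextvars
import random

# Active handler: (tag, mode, store) or None. Hypothetical toy machinery,
# NOT how Oryx works internally (Oryx traces JAX programs instead).
_HANDLER = contextvars.ContextVar("_HANDLER", default=None)


def toy_sow(value, *, tag, name):
    """Tag an intermediate value; what happens depends on the active handler."""
    handler = _HANDLER.get()
    if handler is None or handler[0] != tag:
        return value  # no matching handler: behave like the identity
    _, mode, store = handler
    if mode == "reap":
        store[name] = value  # record the value under its name
        return value
    return store.get(name, value)  # "plant": substitute if a value was provided


def toy_reap(f, *, tag):
    """Return a function that runs f and returns the dict of values sown with `tag`."""
    def wrapped(*args, **kwargs):
        store = {}
        token = _HANDLER.set((tag, "reap", store))
        try:
            f(*args, **kwargs)
        finally:
            _HANDLER.reset(token)
        return store
    return wrapped


def toy_plant(f, *, tag):
    """Return a function taking (values, *args) that runs f with sown values replaced."""
    def wrapped(values, *args, **kwargs):
        token = _HANDLER.set((tag, "plant", dict(values)))
        try:
            return f(*args, **kwargs)
        finally:
            _HANDLER.reset(token)
    return wrapped


def model(seed, x):
    """Hypothetical one-parameter model whose weight is initialized from `seed`."""
    w = toy_sow(random.Random(seed).uniform(-1.0, 1.0), tag="variable", name="w")
    return w * x


params = toy_reap(model, tag="variable")(0, 3.0)  # {'w': <initial weight>}
apply = toy_plant(lambda x: model(0, x), tag="variable")
print(apply({"w": 2.0}, 3.0))  # 6.0
```

Reaping with a tag other than `"variable"` returns an empty dict, because `toy_sow` only records values whose tag matches the active handler; this tag filtering is what lets the mini state library pick out just the parameters.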
The rules are probably not as complete as I'd like them to be, mainly for lack of time/demand. I'm happy to accept new rules in a PR, though! Control flow is a big hole in the rules right now: inverting something like
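The idea behind per-primitive inverse rules can be sketched in plain Python. This is a toy under stated assumptions, not Oryx's implementation: a "program" is a hypothetical list of `(primitive, constant)` steps over scalars, and the hypothetical `RULES` table pairs each primitive with its forward map, its inverse, and the log of the absolute derivative of the inverse. Inverting walks the program backwards and accumulates the inverse log-determinant of the Jacobian (ILDJ), which is what a change-of-variables log density needs. A step with no rule, standing in for unsupported control flow, raises an error.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical rule table: primitive -> (forward, inverse, log|d inverse / dy|).
# Each function takes (value, constant).
Fn = Callable[[float, float], float]
RULES: Dict[str, Tuple[Fn, Fn, Fn]] = {
    # y = x + c  ->  x = y - c,  dx/dy = 1
    "add": (lambda x, c: x + c, lambda y, c: y - c, lambda y, c: 0.0),
    # y = x * c (c != 0)  ->  x = y / c,  dx/dy = 1 / c
    "mul": (lambda x, c: x * c, lambda y, c: y / c,
            lambda y, c: -math.log(abs(c))),
    # y = exp(x)  ->  x = log(y),  dx/dy = 1 / y
    "exp": (lambda x, c: math.exp(x), lambda y, c: math.log(y),
            lambda y, c: -math.log(y)),
}

Program = List[Tuple[str, float]]  # straight-line list of (primitive, constant)


def run(program: Program, x: float) -> float:
    """Evaluate the program forwards."""
    for prim, c in program:
        x = RULES[prim][0](x, c)
    return x


def toy_inverse_and_ildj(program: Program, y: float) -> Tuple[float, float]:
    """Invert the program step by step, accumulating log|dx/dy|."""
    ildj = 0.0
    for prim, c in reversed(program):
        if prim not in RULES:
            raise NotImplementedError(f"no inverse rule for {prim!r}")
        _, inverse, step_ildj = RULES[prim]
        ildj += step_ildj(y, c)  # evaluated at this step's output
        y = inverse(y, c)
    return y, ildj


prog = [("mul", 2.0), ("add", 1.0), ("exp", 0.0)]  # y = exp(2x + 1)
y = run(prog, 0.5)
x, ildj = toy_inverse_and_ildj(prog, y)
```

For `exp(2x + 1)` the accumulated ILDJ is `-log(y) - log(2)`, matching the log-derivative of the closed-form inverse `(log(y) - 1) / 2`.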
@sharadmv I've sent you a private email about this (private, because we're still working closed source).
Right, this is pretty interesting. Because Gen doesn't assume any restrictions on the return value function, I am curious what happens if I use an Oryx model without control flow, which supports (re: when I make comments about Gen + Oryx, I'm thinking of Oryx as providing a DSL for defining objects with
Hi all!
Will Oryx continue to be actively maintained? Are there maintainers who are hoping to continue working on the package?