Refactor old Distribution base class #5308

Open
Tracked by #7053
ricardoV94 opened this issue Jan 3, 2022 · 31 comments

Comments

@ricardoV94
Member

ricardoV94 commented Jan 3, 2022

PyMC distribution classes are odd objects that bundle RandomVariables together with logp, logcdf and moment methods (effectively doing runtime dispatching). They also manage most of the non-RandomVariable kwargs that users are familiar with (observed, transformed, size/dims), as well as behind-the-scenes actions such as registration in the model.

This exists mostly for backwards compatibility with V3 and ease of developer refactoring, but the current result is far from pretty.

We need to figure out a more elegant/permanent architecture now that many things that existed to accommodate V3 limitations no longer hold.

Distribution

Distribution is currently performing the following tasks:

class Distribution(metaclass=DistributionMeta):

  1. Input validation:
    1. Raising FutureWarnings for testval kwarg
    2. Raising TypeError when distribution is initialized outside of a Model context
    3. Raising TypeError when name is not given to a distribution
    4. Raising ValueError when more than one of dims/shape/size is given
  2. Convert alternative parametrizations to the standard parametrization (e.g., tau -> sigma). This is done by the .dist methods.
  3. Add informative attribute errors for deprecated logp, logcdf, random methods
  4. Resize the final RV based on observed, shape, dims or size
  5. Provide the .dist() API to create an unnamed RV that is not registered in the model. This type of variable is necessary for use in Potentials and in other distribution factories that use RVs as building blocks, such as Bound and Censored distributions, as well as Mixtures and Timeseries once they get refactored for V4.
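A pure-Python sketch of two of these tasks: the "at most one of dims/shape/size" check (1.4) and the tau -> sigma conversion (2). Both helper names are hypothetical; in PyMC the conversion is done by `get_tau_sigma`, whose exact behavior this sketch does not try to reproduce.

```python
import math


def check_single_shape_arg(dims=None, shape=None, size=None):
    """Raise if more than one of dims/shape/size is given (hypothetical helper)."""
    given = [name for name, value in (("dims", dims), ("shape", shape), ("size", size))
             if value is not None]
    if len(given) > 1:
        raise ValueError(f"Pass at most one of dims/shape/size, got: {', '.join(given)}")


def to_sigma(sigma=None, tau=None):
    """Map the precision parametrization onto the standard deviation one."""
    if sigma is not None and tau is not None:
        raise ValueError("Can't pass both tau and sigma")
    if tau is not None:
        if tau <= 0:
            raise ValueError("tau must be positive")
        return 1.0 / math.sqrt(tau)  # precision: tau = 1 / sigma**2
    return 1.0 if sigma is None else sigma


print(to_sigma(tau=4.0))  # 0.5
```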

DistributionMeta

In addition we have a DistributionMeta that does the following:

class DistributionMeta(ABCMeta):

  1. Dispatch the logp, logcdf, moment, default_transform methods defined in the old PyMC distributions to apply to the respective rv_op
  2. Register the rv_op type as subclass of the old style PyMC distribution, so that V3 Discrete/Continuous subclass checks still work?
isinstance(pm.Normal.dist().owner.op, pm.distributions.Continuous)  # True
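Since `DistributionMeta` subclasses `ABCMeta`, point 2 presumably relies on the virtual-subclass mechanism that `ABCMeta.register` provides. A self-contained sketch with stand-in classes (`Continuous` and `NormalRV` here are toys, not the PyMC/Aesara classes):

```python
from abc import ABC


class Continuous(ABC):
    """Stand-in for the old-style PyMC base class."""


class NormalRV:
    """Stand-in for the RandomVariable Op type; it does not inherit from Continuous."""


# Virtual-subclass registration: no inheritance, yet isinstance/issubclass agree
Continuous.register(NormalRV)

print(isinstance(NormalRV(), Continuous))  # True
print(issubclass(NormalRV, Continuous))    # True
```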

If we want to get rid of Distribution, we probably need to statically dispatch our methods to the respective rv_op. That is nothing special; it is how we do it in aeppl from the get-go: https://github.com/aesara-devs/aeppl/blob/38d0c2ea4ecf8505f85317047089ab9999d2f78e/aeppl/logprob.py#L104-L130
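A minimal sketch of static dispatch keyed on the Op type with `functools.singledispatch`. The linked aeppl code follows a similar pattern, but `logprob`, `NormalRV` and `ExponentialRV` below are stand-ins, not its actual API:

```python
import math
from functools import singledispatch


class NormalRV:
    """Stand-in for a RandomVariable Op type."""


class ExponentialRV:
    """Another stand-in Op type, deliberately left without a logprob."""


@singledispatch
def logprob(op, value, *params):
    """Dispatch on type(op); no class hierarchy or metaclass is involved."""
    raise NotImplementedError(f"No logprob registered for {type(op).__name__}")


@logprob.register(NormalRV)
def _normal_logprob(op, value, mu, sigma):
    # log N(value | mu, sigma)
    z = (value - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


print(logprob(NormalRV(), 0.0, 0.0, 1.0))  # about -0.9189
```

Each registration can live right after the Op it belongs to, so code stays grouped by distribution even without a class body.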

@ricardoV94
Member Author

ricardoV94 commented Jan 3, 2022

They serve the same goal as the "standard" Distribution class does right now: basically managing all non-distribution kwargs (shape/dims, transformed, observed) and behind-the-scenes magic such as value variables and RNGs.

The base Distribution class was only ever supposed to make the hand-off to the Model class, which would then handle all the generic tasks (e.g. value variables, RNGs, shape/dims, etc.), and the rest was just for backward compatibility (e.g. the now completely unnecessary Distribution.dist interface). Notice how instances of these classes are never made, and that only static/class/type-level functions are used.

By extending Distribution and adding more logic to those classes, it could become considerably more difficult to unwind those temporary backward compatibility-only choices, and they'll eventually become a permanent and confusing part of v4's design.

This isn't the time or place to start addressing all that, but these are the kinds of design considerations that need to go alongside changes/additions to the relevant areas of code (i.e. Distribution-related code).

Originally posted by @brandonwillard in #5169 (comment)

@fonnesbeck
Member

Do we have a spec in mind for a replacement class? Maybe an explicit list of what needs to be removed and added with respect to the current class would be a good place to start.

@ricardoV94
Member Author

Do we have a spec in mind for a replacement class? Maybe an explicit list of what needs to be removed and added with respect to the current class would be a good place to start.

I will update the issue tomorrow with those

@canyon289
Member

to accommodate V3 limitations no longer hold.
If you're easily able to list these no-longer-valid limitations while writing the doc, I would also appreciate that (to help with my personal understanding). No obligation if it'll take you a lot of time.

@ricardoV94
Member Author

Updated the top post to mention every function of Distribution classes that I could find. Also highlighted some of the points I think will be a bit more challenging, or may need more drastic API changes.

@ricardoV94 ricardoV94 changed the title Refactor old pymc.distribution base class Refactor old Distribution base class Jan 4, 2022
@fonnesbeck
Member

eval() should be aliased to something more intuitive like sample() or draw() where appropriate.

@ricardoV94
Member Author

ricardoV94 commented Jan 4, 2022

eval() should be aliased to something more intuitive like sample() or draw() where appropriate.

eval is just the standard Aesara debug feature that exists for any node. We can have a function wrapper that compiles the "proper" Aesara function and takes a given number of draws. By analogy with V3 it could be named pm.random, but I like the name draw better.

The question of RNGs with default updates remains either way, but it could arguably be offloaded to pymc.aesaraf.compile_pymc, which already does this for Simulator variables and which pm.random/pm.draw would call. That would be one less thing for Distribution to be concerned with.

# Set the default update of a NoDistribution RNG so that it is automatically
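The difference a default RNG update makes can be mimicked with numpy alone. In this sketch, `eval_fixed_state` stands in for `.eval()` on a variable whose RNG state is never updated, and `draw` is a hypothetical helper that threads one advancing generator through repeated calls; neither is PyMC's or Aesara's actual code.

```python
import numpy as np


def eval_fixed_state(seed=0):
    """Mimic evaluating an RV whose RNG has no default update: the state never advances."""
    return np.random.default_rng(seed).normal()


def draw(sample_fn, draws=1, random_seed=None):
    """Hypothetical draw helper: one generator, advanced by every call to sample_fn."""
    rng = np.random.default_rng(random_seed)
    return np.array([sample_fn(rng) for _ in range(draws)])


print(eval_fixed_state() == eval_fixed_state())  # True: the same value every time
samples = draw(lambda rng: rng.normal(0.0, 1.0), draws=5, random_seed=1)
print(samples.shape)  # (5,)
```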

@fonnesbeck
Member

Yeah, I like a functional approach there, and I agree draw is better than sample since the latter is already associated with MCMC sampling.

@michaelosthege
Member

What is the selling point of it?
It's a very big refactoring that forces an API break on us.
And after adding all those decorators we'd probably end up with more lines of code per distribution; considerably harder to comprehend.


First of all, if we never create instances of any PyMC distribution, why do they exist at all?

We distinguish between variables that are

  • registered by name in a model: pm.MyDistribution("name", some, params) or
  • created as unregistered, auxiliary variables, that may still contribute to a model graph: pm.MyDistribution.dist(some params)

Pure Aesara RV Ops don't make this distinction (point 7).

Then PyMC distributions have a much richer API & behavior compared to Aesara RandomVariable Ops (points 1,2,3,5).

So we need to wrap the Op.
But should these adapters be a class MyDistribution(...) or a def MyDistribution(...)?

Neither MyDistribution nor MyDistribution.dist() seem to get much love, and generally it feels like there's a recent hype for functional API designs.
But class vs. function aside, let's consider just the resulting user API syntax:

  1. pm.Normal("name", 0, 1) and pm.Normal.dist(0, 1) → the current API. Done with class Normal(pm.Distribution) because otherwise we'd be monkey-patching things onto functions and nobody wants that. We can still rename .dist() to .cool() or something.
  2. pm.Normal("name", 0, 1) and pm.Normal(0, 1) → doesn't work as long as name, mu, sd etc. are positional. Same for having a kwarg like pm.Normal(0, 1, register=False). In C# this would be trivial, by the way.
  3. pm.Normal("name", 0, 1) and pm.normal(0, 1) → Both could be functions, but maybe also easy to confuse?

So we could move away from class MyDistribution(Distribution) only if we switch to syntax 3.

Then every distribution needs to get a def mydistribution(...) and def MyDistribution(...).
The latter could be a one-liner: MyDistribution = pm.as_rich_RV(mydistribution).

Then we'd have to dispatch logp and get_moment etc. onto the RV Op as @ricardoV94 suggested.
The resulting code would no longer be grouped by distribution, and have a ton of dispatch decorators doing the job that's currently done by DistributionMeta.


  1. Taking care of default updates for the RandomState variables so that returned variables "look" random by default.

This once bit me really hard when I tried to demo Aesara to someone.
I wanted to show how conveniently one can do some symbolic calculations and .eval() things as random numbers.
...it took me an hour to understand that I was used to a really awesome behavior that was actually a PyMC feature, not an Aesara one.

If Aesara made the non-deterministic .eval() the default (as with PyMC RVs), it would be a lot more beginner-friendly and interesting for people doing, for example, Monte Carlo sampling.

@ricardoV94
Member Author

pm.Normal("name", 0, 1) and pm.Normal(0, 1) → doesn't work as long as name, mu, sd etc. are positional. Same for having a kwarg like pm.Normal(0, 1, register=False). In C# this would be trivial, by the way.

Can't we just check if the first variable is a string and react accordingly?
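A pure-Python sketch of that check. `MODEL_VARS` and `normal_dist` are hypothetical stand-ins for the model's registry and the unregistered constructor. The trade-off is that a string can then never be a valid first distribution parameter.

```python
MODEL_VARS = {}  # hypothetical stand-in for a model's named-variable registry


def normal_dist(mu=0.0, sigma=1.0):
    """Hypothetical unregistered constructor (a dict instead of a TensorVariable)."""
    return {"op": "normal", "mu": mu, "sigma": sigma}


def Normal(*args, **kwargs):
    # A leading string is taken as the variable name; otherwise behave like .dist()
    if args and isinstance(args[0], str):
        name, *params = args
        rv = normal_dist(*params, **kwargs)
        MODEL_VARS[name] = rv
        return rv
    return normal_dist(*args, **kwargs)


x = Normal("x", 0.0, 1.0)  # registered under "x"
y = Normal(0.0, 2.0)       # unregistered
print(sorted(MODEL_VARS))  # ['x']
```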

@ricardoV94
Member Author

ricardoV94 commented Jan 9, 2022

Then we'd have to dispatch logp and get_moment etc. onto the RV Op as @ricardoV94 suggested.
The resulting code would no longer be grouped by distribution, and have a ton of dispatch decorators doing the job that's currently done by DistributionMeta.

We already do that anyway, just via the cryptic DistributionMeta. Structurally, what difference does it make whether logp and logcdf are fake methods inside fake classes or real functions one after the other?

Does this seem so bad?
https://github.com/aesara-devs/aeppl/blob/333108143a1f4b63d9fcd9842dbe35457025c180/aeppl/logprob.py#L104-130

More importantly users should not be calling these pseudo methods themselves because they expect as inputs the already parsed and symbolic canonical parameters of the distribution. For some (many) distributions these have nothing to do with what they would pass into .dist, so making those accessible from the distribution classes (via autocompletion) is actually error prone.

@mitch-at-orika

Posting this in case it helps... I came across an issue like this one on Discourse while trying to use v3 code in v4. In some of the plots in BCB, the authors use FreeRV.distribution.all_trees[..].predict_output() to create insightful plots. In v4 the error is 'TensorVariable' object has no attribute 'distribution'.

I tried to solve it, but it seems that in v3 model.py used def var() and returned a <>RV PyMC object where distribution=dist. However, in v4 we have def register_rv(), which returns an Aesara tensor variable? Are we breaking this functionality in v4? Sorry if it is something obvious I am missing.

@michaelosthege
Copy link
Member

@mitch-at-orika yes, there is a breaking change w.r.t. what's returned. In most cases that's not a problem, but it looks like BART monkey-patched some information onto the TensorVariable. This is no longer possible with v4 where these calls don't necessarily return the same tensor that was created by the underlying distribution.

@aloctavodia I believe you can link to relevant BART issues, or even explain how to access the all_trees object in v4?

@mitch-at-orika

Thanks for the explanation, Michael. It is reassuring that it was a BART-only feature; I had thought I had simply missed this functionality of PyMC variables.

@ricardoV94
Member Author

ricardoV94 commented Jan 30, 2024

Here is some pseudo-code that might suffice?

from functools import partial, wraps
import pytensor.tensor.random.basic as ptr

from pymc.distributions.continuous import get_tau_sigma
from pymc.pytensorf import convert_observed_data
from pymc.distributions.shape_utils import convert_dims, shape_from_dims
from pymc.model import modelcontext
from pymc.util import UNSET


def handle_shape(ndim_supp=None):
    """Convert the shape argument to size used by PyTensor."""

    def inner_decorator(dist):

        @wraps(dist)
        def inner_func(*args, size=None, shape=None, **kwargs):
            if shape is not None and size is not None:
                raise ValueError("Cannot pass both size and shape")
            if shape is not None:
                # If needed, call dist without size to find out ndim_supp
                local_ndim_supp = dist(*args, **kwargs).owner.op.ndim_supp if ndim_supp is None else ndim_supp
                size = shape if local_ndim_supp == 0 else shape[:-local_ndim_supp]
            return dist(*args, size=size, **kwargs)

        return inner_func

    return inner_decorator


def register_model_rv(dist, rv_type=None):
    """Register a random variable in a model context."""

    @wraps(dist)
    def inner_func(name, *args, dims=None, transform=UNSET, observed=None, model=None, **kwargs):
        # Resolve the model first: shape_from_dims needs the actual model, not None
        model = modelcontext(model)
        if dims is not None:
            dims = convert_dims(dims)
        if observed is not None:
            observed = convert_observed_data(observed)

        # The shape of the variable is determined from the following sources:
        # size or shape, otherwise dims, otherwise observed.
        if kwargs.get("size") is None and kwargs.get("shape") is None:
            if dims is not None:
                kwargs["shape"] = shape_from_dims(dims, model)
            elif observed is not None:
                kwargs["shape"] = tuple(observed.shape)

        rv = dist(*args, **kwargs)
        return model.register_rv(rv, name=name, dims=dims, transform=transform, observed=observed)

    # Monkey-patch useful attributes
    if rv_type is not None:
        inner_func.rv_type = rv_type
    inner_func.dist = dist
    return inner_func


@handle_shape(ndim_supp=0)
def normal_dist(mu=0, sigma=None, tau=None, **kwargs):
    _, sigma = get_tau_sigma(sigma=sigma, tau=tau)
    return ptr.normal(mu, sigma, **kwargs)

Normal = register_model_rv(normal_dist, rv_type=ptr.NormalRV)

This also makes writing distribution helpers simpler. Right now we have to define a redundant __new__ method to preserve the usual API:

def __new__(cls, name, w, mu, sigma=None, tau=None, comp_shape=(), **kwargs):
    _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
    return Mixture(name, w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs)

@classmethod
def dist(cls, w, mu, sigma=None, tau=None, comp_shape=(), **kwargs):
    _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
    return Mixture.dist(w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs)

Instead this could be done like this:

from pymc.distributions.mixture import Mixture

def normal_mixture_dist(w, mu, sigma=None, tau=None, **kwargs):
    return Mixture.dist(w, Normal.dist(mu, sigma=sigma, tau=tau), **kwargs)

NormalMixture = register_model_rv(normal_mixture_dist)
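The shape-to-size rule used in handle_shape above can be checked in isolation. The explicit `ndim_supp == 0` branch matters because `shape[:-0]` is `shape[:0]`, an empty tuple. The function name below is hypothetical; for reference, PyTensor's Normal has `ndim_supp == 0` and MvNormal has `ndim_supp == 1`.

```python
def size_from_shape(shape, ndim_supp):
    """Drop the trailing support dimensions of shape to get the batch size."""
    shape = (shape,) if isinstance(shape, int) else tuple(shape)
    # The explicit branch is needed because shape[:-0] is shape[:0] == ()
    return shape if ndim_supp == 0 else shape[:-ndim_supp]


print(size_from_shape((3, 4), ndim_supp=0))     # (3, 4): scalar support
print(size_from_shape((3, 4, 2), ndim_supp=1))  # (3, 4): vector support
print(size_from_shape(5, ndim_supp=0))          # (5,)
```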

@thomasaarholt
Contributor

@ricardoV94 I think I understand this, but just to be certain:

You have written a function that constructs the correct pytensor TensorVariable, and then you have a wrapper class that associates that variable with whatever model context manager this is created within.

Yes, I believe this should work. Here is an example of how you would type register_model_rv in order to propagate the function signature of normal_dist onto Normal.

@thomasaarholt
Copy link
Contributor

I think this introduces another potential problem, however. This new Normal object does not have the .dist method or any other methods or properties associated with it. Or perhaps that's ok, and this would be a breaking change?

@ricardoV94
Member Author

ricardoV94 commented Jan 30, 2024

I think this introduces another potential problem, however. This new Normal object does not have the .dist method or any other methods or properties associated with it. Or perhaps that's ok, and this would be a breaking change?

They are being monkey-patched here:

    # Monkey-patch useful attributes
    if rv_type is not None:
        inner_func.rv_type = rv_type
    inner_func.dist = dist
    return inner_func

I am not sure that's the best approach, but the current fake classes also seem odd. Maybe what's done by Distribution.__new__() now should be done by Distribution().__call__()? It's still a pretty useless object with only static methods, but it binds the two methods more transparently

There is no other method that should be attached to Normal, including logp and logcdf. Those should all be accessed through pm.logp() and the like, because their signature is an implementation detail (i.e., which canonical parametrization we decide to use) that the user shouldn't need to be aware of. In that regard it's a good thing that something like Normal.logp will cease to exist.

@ricardoV94
Member Author

ricardoV94 commented Jan 30, 2024

Here is a non-fake class that does the same:

import pytensor.tensor.random.basic as ptr

from pymc.distributions.continuous import get_tau_sigma
from pymc.pytensorf import convert_observed_data
from pymc.distributions.shape_utils import convert_dims, shape_from_dims
from pymc.model import modelcontext
from pymc.util import UNSET
from pymc.distributions.mixture import Mixture


class Distribution:
    rv_type = None
    rv_op = None

    @classmethod
    def dist(cls, *args, size=None, shape=None, **kwargs):
        if shape is not None and size is not None:
            raise ValueError("Cannot pass both size and shape")
        if shape is not None:
            ndim_supp = getattr(cls.rv_type, "ndim_supp", None)
            if ndim_supp is None:
                # If needed, call rv_op without size to find out ndim_supp
                ndim_supp = cls.rv_op(*args, **kwargs).owner.op.ndim_supp
            size = shape if ndim_supp == 0 else shape[:-ndim_supp]
        return cls.rv_op(*args, size=size, **kwargs)
    
    def __call__(self, name, *args, dims=None, transform=UNSET, observed=None, model=None, **kwargs):
        # Resolve the model first: shape_from_dims needs the actual model, not None
        model = modelcontext(model)
        if dims is not None:
            dims = convert_dims(dims)
        if observed is not None:
            observed = convert_observed_data(observed)

        # The shape of the variable is determined from the sources:
        # size or shape, otherwise dims, otherwise observed.
        if kwargs.get("size") is None and kwargs.get("shape") is None:
            if dims is not None:
                kwargs["shape"] = shape_from_dims(dims, model)
            elif observed is not None:
                kwargs["shape"] = tuple(observed.shape)

        rv = self.dist(*args, **kwargs)
        return model.register_rv(rv, name=name, dims=dims, transform=transform, observed=observed)


class NormalDist(Distribution):
    rv_type = ptr.NormalRV
    
    @staticmethod
    def rv_op(mu=0, sigma=None, tau=None, **kwargs):
        _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
        return ptr.normal(mu, sigma, **kwargs)


class NormalMixtureDist(Distribution):
    # If we subclass from a refactored `Mixture`, this `rv_type` would be obtained automatically
    rv_type = Mixture.rv_type
    
    @staticmethod
    def rv_op(w, mu, sigma=None, tau=None, **kwargs):
        _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
        return Mixture.dist(w, Normal.dist(mu=mu, sigma=sigma), **kwargs)
    

Normal = NormalDist()
NormalMixture = NormalMixtureDist()

Its sole point is that it provides a Normal.dist method! Nothing else.

@thomasaarholt
Contributor

thomasaarholt commented Jan 30, 2024

I believe we desire two interfaces per distribution that require similar (but not identical) signatures:
We want a dist function, and a "create a distribution and register it with a model" function. Is that a correct interpretation?

For the case of the normal distribution, and only using a minimal number of variables to represent the difference:

def normal_dist(mu, sigma, **kwargs):
    ...
    
def Normal(name, mu, sigma, dims, observed, **kwargs):
    ...

I do not believe there is a way to programmatically produce a static function signature for Normal by using the one for normal_dist through such mechanisms as ParamSpec, as discussed in this gist.

In this case Normal is a function, but it could easily be a class like it is today, and we could keep .dist as its method. I like your suggestions for a simpler inheritance structure. But wouldn't it be better to put the contents of __call__ inside __init__ instead? Then rename NormalDist to Normal and call super().__init__() inside that class's __init__? Then we only refer to Distribution and Normal, rather than Distribution, NormalDist and Normal.

The only solution that I think makes sense is to type the signature out in both cases, using mechanisms like TypedDict to keep kwargs nice and minimize the amount of code duplication. I'm working on an MVP in a gist to demonstrate.

@ricardoV94
Member Author

ricardoV94 commented Jan 30, 2024

But wouldn't it be better to stick the contents of __call__ inside __init__ instead?

__init__ is not allowed to return anything other than None. PyMC distributions return completely new objects (TensorVariables), and in that sense __new__, as used now, is appropriate, albeit overkill, because it's just a fancy function call.
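This is plain Python semantics, as a small standalone example shows (the class names are made up for illustration):

```python
class ReturnsOther:
    """__new__ may return an object that is not an instance of the class."""

    def __new__(cls, value):
        return [value]  # a list, not a ReturnsOther, so __init__ is skipped


class InitReturns:
    def __init__(self, value):
        return value  # not allowed: __init__ must return None


print(type(ReturnsOther(1)).__name__)  # list
try:
    InitReturns(1)
except TypeError as err:
    print("TypeError:", err)
```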

I believe we desire two interfaces per distribution that require similar (but not identical) signatures:
We want a dist function, and a "create a distribution and register it with a model" function. Is that a correct interpretation?

Yes that's correct

@thomasaarholt
Contributor

thomasaarholt commented Jan 31, 2024

Ah yes, that makes sense! Thank you for clarifying these things for me!

@thomasaarholt
Contributor

I took a stab at rewriting the Normal class as a function (def Normal), and wrote it up in this gist.

@ricardoV94
Member Author

ricardoV94 commented Oct 29, 2024

I took a stab at rewriting the Normal class as a function (def Normal), and wrote it up in this gist.

Looks good. However, the whole code inside Normal is distribution-agnostic so it should just be some decorator in the end? Something like Normal = [wrap_dist|model_var_from_dist](Normal_dist) that copies / extends the docstrings from dist (if that's possible)?

Also, how bad would it be to monkey-patch Normal.dist = Normal_dist to keep backward compatibility?
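A sketch of such a decorator, assuming the model is a toy dict and that appending a note to the copied docstring is enough; `model_var_from_dist` is one of the hypothetical names floated above, and tuples stand in for TensorVariables.

```python
import functools

MODEL_VARS = {}  # hypothetical stand-in for the active model's registry


def model_var_from_dist(dist):
    """Build the registering constructor from a dist function (hypothetical)."""

    @functools.wraps(dist)  # copies __name__, __doc__ and sets __wrapped__
    def model_var(name, *args, **kwargs):
        rv = dist(*args, **kwargs)
        MODEL_VARS[name] = rv
        return rv

    model_var.__doc__ = (dist.__doc__ or "") + "\n\nThe first argument is the variable name."
    model_var.dist = dist  # monkey-patch for backward compatibility
    return model_var


def normal_dist(mu=0.0, sigma=1.0):
    """Normal distribution with mean mu and standard deviation sigma."""
    return ("normal", mu, sigma)


Normal = model_var_from_dist(normal_dist)
Normal("x", 1.0)
print(Normal.dist(0.0, 2.0))  # ('normal', 0.0, 2.0)
```

One caveat: because `functools.wraps` sets `__wrapped__`, `inspect.signature(Normal)` reports `normal_dist`'s signature, without the leading `name`.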

@thomasaarholt
Contributor

thomasaarholt commented Oct 30, 2024

These are good points! I haven't quite worked out a decorator approach, but I figured out a way to keep the current API while refactoring things to a "function-based" approach. Normal is still a class, but used quite differently.

The key part is Normal = NormalDistribution(), and Normal is the thing that is exposed as pm.Normal. Take a look! (Note, I am using the master branch for the imports)

@ricardoV94
Member Author

You need to consider the observed data's shape before you create the dist, so that pm.Normal("x", observed=np.zeros(5)) has the right shape.

Anyway, the downside of your last approach is that you need to manually duplicate the signatures of dist and __call__. Also, initializing Normal = NormalDistribution() doesn't seem much better, since we never really use the class, but I see the advantage that dist now exists transparently on it. It should be a staticmethod, though, since self is never used.
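For the observed-shape point, a numpy-only sketch of the resolution order used in the register_model_rv pseudo-code earlier: explicit size/shape first, then dims looked up in the model's coords, then the observed data's shape, all decided before the dist is created. `resolve_shape` and the `coords` dict are hypothetical stand-ins.

```python
import numpy as np


def _as_tuple(value):
    return (value,) if isinstance(value, int) else tuple(value)


def resolve_shape(size=None, shape=None, dims=None, observed=None, coords=None):
    """Return the (kwarg name, value) pair to pass to the dist constructor."""
    if size is not None:
        return "size", _as_tuple(size)
    if shape is not None:
        return "shape", _as_tuple(shape)
    if dims is not None:
        # Each dim name maps to the length of its coordinate values
        return "shape", tuple(len(coords[dim]) for dim in dims)
    if observed is not None:
        # Known before the dist is created, so the RV gets the observed shape
        return "shape", np.shape(observed)
    return "size", None  # fall back to the parameters' broadcast shape


print(resolve_shape(observed=np.zeros(5)))  # ('shape', (5,))
print(resolve_shape(dims=("city",), coords={"city": ["a", "b", "c"]}))  # ('shape', (3,))
```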

@thomasaarholt
Contributor

the downside of your last approach is you need to manually duplicate the signature of the dist and call

Assuming that our goal is to have a function signature for the Normal and Normal_dist functions that can help people correctly pass arguments to these functions, then I think there is no way around that. I don't find this particularly problematic though.

I've refactored the gist into a small repo, just to be able to spread the functions and classes into separate modules.
Here is the current version of the normal distribution, and here is an attempt at a decorator version - but I don't think this is a good idea. It adds a lot of boilerplate code with little gain.

@ricardoV94
Member Author

The idea of the decorator was to wrap dist, so you only have to define one of the objects (dist in this case).

We don't need classes if the only reason for them is that Normal.dist exists. I'd rather just monkey-patch, tbh. There's nothing "class"-y about them otherwise: they don't hold state or do anything.

@thomasaarholt
Contributor

Right, but the problem with monkey patching is that we won't get autocompletion of the Normal.dist monkey-patched function when importing Normal. Within the same file, pyright will give you the docstring and arguments:
[Screenshot: within the same file, pyright shows the docstring and arguments of the monkey-patched Normal.dist]

But if we monkey patch in one file, and then import it in another file, then pyright won't tell us that the method exists. Here I define it on the left and import it in the file on the right.
[Image: Normal.dist monkey-patched in one file (left) and imported in another (right)]

We don't get autocomplete, arguments or docstring:
[Screenshot: no autocomplete, arguments or docstring for the imported Normal.dist]

This is why the class solution is necessary - it's the only one that will let us "attach" stuff to the object in a manner that type checkers will recognize.

@ricardoV94
Member Author

ricardoV94 commented Oct 30, 2024

Right, but then what changed in your previous attempt? We moved __new__ into __call__ and duplicated the signatures between __call__ and dist?

Not saying it's worse, just double checking what we're trading off

@thomasaarholt
Contributor

thomasaarholt commented Oct 31, 2024

Right, I'm interpreting your question as something like "We wanted to move from a class-based approach to a function-based approach, but now we have moved from one class approach to another one, how is this better?". I think that's a very valid question.

Just to summarize, the goals I see with this github issue are:

  1. Get rid of the complicated metaclass
  2. Add proper arguments to the distributions, resolving Add distribution parameters as named positional arguments #6083
  3. Keep the previous API of e.g. Normal() and Normal.dist() (keep the dist method on the Normal object).

So, in pymc today, Normal.dist has a proper function signature, while Normal doesn't (it uses *args) - because it gets intercepted by the metaclass's __new__ method.

Using that metaclass approach, we have no way to statically (by that I mean "in a manner in which the type checker can infer without running the code") give our Distribution named arguments (e.g. mu and sigma). I've tried a ton of things in the past.

Using only functions (def normal and def normal_dist) we could do everything except keep the previous API as per my previous comment above about the problems with monkey-patching.

My solution solves the monkey-patching problem of the function approach by statically defining the dist method on the class, so that it is well-defined at class definition time. You can think of my class approach as more of an API factory for keeping the same API.

The code flow is also completely linear: The Distribution-class approach is nonlinear because it defines rv_op on the subclass, but refers to it in the superclass and metaclass.
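A toy version of that "API factory" pattern, with `dist` as a `staticmethod` as suggested earlier in the thread. Because `dist` is declared in the class body, a type checker can resolve `Normal.dist` wherever `Normal` is imported, at the cost of writing the parameters out twice. All names are hypothetical, the model is a dict, and tuples stand in for TensorVariables.

```python
MODEL_VARS = {}  # hypothetical stand-in for the active model's registry


class NormalDistribution:
    """Holds no state: it only groups the two entry points under one name."""

    @staticmethod
    def dist(mu: float = 0.0, sigma: float = 1.0) -> tuple:
        """Create an unregistered variable."""
        return ("normal", mu, sigma)

    def __call__(self, name: str, mu: float = 0.0, sigma: float = 1.0, *, dims=None) -> tuple:
        # The parameters are written out again so tools see named arguments
        rv = self.dist(mu, sigma)
        MODEL_VARS[name] = (rv, dims)
        return rv


Normal = NormalDistribution()  # the single instance exposed as the public Normal
Normal("x", 0.0, 1.0, dims=("obs",))
print(Normal.dist(sigma=2.0))  # ('normal', 0.0, 2.0)
```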

6 participants