
Implement betainc and derivatives #464

Merged: 1 commit merged into aesara-devs:main on Jul 6, 2021

Conversation

Contributor

@ricardoV94 ricardoV94 commented Jun 6, 2021

This PR adds an Op equivalent to SciPy's betainc, along with Python-only Ops that approximate its derivatives with respect to the first two arguments. More context can be found in pymc-devs/pymc#4736

One of the scalar gradient tests is failing locally because the expected nan return raises a ValueError in the test context, whereas it only issues a RuntimeWarning when run in the REPL. Other scalar gradient tests fail in CI due to numerical issues but pass locally. Probably I need to specify the compilation mode in more detail (and if so, which one)?
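For reference, here is a minimal sketch of how the new Op would be exercised, mirroring the test added in this PR; the `aesara.tensor.math.betainc` import path is an assumption based on the diff, and the reference gradient values come from the Stan-based test case further below.

```python
import aesara.tensor as aet
from aesara import function
from aesara.tensor.math import betainc  # assumed import path, per this PR's diff

a, b, z = aet.scalars("a", "b", "z")
betainc_out = betainc(a, b, z)
# The gradient graph is where the new Python-only derivative Ops are used
betainc_grad = aet.grad(betainc_out, [a, b])
f_grad = function([a, b, z], betainc_grad)

# Reference values from the test case in this PR
print(f_grad(1.0, 1.0, 0.4))  # approx. [-0.36651629, 0.30649537]
```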


Here are a few important guidelines and requirements to check before your PR can be merged:

  • There is an informative high-level description of the changes.
  • The description and/or commit message(s) references the relevant GitHub issue(s).
  • pre-commit is installed and set up.
  • The commit messages follow these guidelines.
  • The commits correspond to relevant logical changes, and there are no commits that fix changes introduced by other commits in the same branch/PR. If your commit description starts with "Fix...", then you're probably making this mistake.
  • There are tests covering the changes introduced in the PR.

Member

@brandonwillard brandonwillard left a comment


Looks good!

We need to determine whether or not an Aesara implementation of _betainc_derivative is a reasonable replacement for the Python-only implementation via BetaIncDd[a|b] before merging.

If it's not possible to make an Aesara version that's comparable to the Python version in a reasonable amount of time/effort, then we can create a separate issue for that and merge this in the meantime.

Comment on lines 1119 to 1106
class BetaIncDdb(TernaryScalarOp):
    """
    Gradient of the regularized incomplete beta function wrt to the second argument (b)
    """

    def impl(self, a, b, x):
        return _betainc_derivative(a, b, x, wrtp=False)


betainc_ddb_scalar = BetaIncDdb(upgrade_to_float_no_complex, name="betainc_ddb")


def _betainc_derivative(p, q, x, wrtp=True):
Member

@brandonwillard brandonwillard Jun 6, 2021


Is the/a Aesara implementation of _betainc_derivative prohibitively slow compared to this Python implementation?

If not, we should definitely use the Aesara implementation. If it is, we need to figure out why, open an independent line of investigation into the cause, and fix it.

Contributor Author

@ricardoV94 ricardoV94 Jun 7, 2021


I actually did not write it!

The BetaInc code, which was just a copy of the C code behind scipy.special.betainc, was what was known to be prohibitively slow (especially to compile), as were the autodiff derivatives obtained from it. I have some crude benchmarks here: https://github.com/ricardoV94/derivatives_betainc/blob/master/comparison_aesara.ipynb

If we want a test case for exploring the slowness of scan, that seems like a good start, since we have SciPy and Aesara running the exact same algorithm under the hood.

The derivatives use a completely different algorithm, so they might be fine. Do you think it's worth trying to convert them to Aesara code? I guess the concern here is that they break the auto-diff chain? Or is it an issue for the backends that would need custom dispatch?

@ricardoV94 ricardoV94 force-pushed the betainc branch 2 times, most recently from e35a39a to 91d36ac on June 7, 2021 17:09
@ricardoV94
Contributor Author

ricardoV94 commented Jun 8, 2021

Are we running any of the jobs with float32? I am puzzled as to why the custom tests fail here but pass locally and also pass on the PyMC3 PR.

The original precision of 7 decimals should be fine on float64, whereas for float32 it should be 3. The current setting is 4, so that could explain it.

Edit: I see now it was float32; it's specified during the "create matrix id" part of the job. I wonder if that could also be part of the test title. It took me by surprise.
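Not part of this PR, but one hypothetical way to make the expected precision track the active float width, so the same assertion works under both the float64 and float32 CI jobs (the helper name and tolerance are illustrative, based on the 7/3-decimal figures above):

```python
import numpy as np
from aesara import config


def assert_allclose_for_floatX(actual, desired):
    # 7 decimals under float64, 3 under float32, per the discussion above
    decimal = 7 if config.floatX == "float64" else 3
    np.testing.assert_allclose(actual, desired, atol=1.5 * 10 ** (-decimal))


assert_allclose_for_floatX(0.3064954, 0.30649537)
```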

@brandonwillard
Member

brandonwillard commented Jun 8, 2021

I see now it was float32; it's specified during the "create matrix id" part of the job. I wonder if that could also be part of the test title. It took me by surprise.

Yes, it's also something we need to refactor entirely, because rerunning all the tests under a default of float32 is extremely time-consuming and does not provide any additional coverage that couldn't be achieved more directly in a fraction of the time.

@ricardoV94 ricardoV94 force-pushed the betainc branch 2 times, most recently from 9fa58fc to ffb575f on June 8, 2021 07:52
@ricardoV94
Contributor Author

ricardoV94 commented Jun 9, 2021

I am getting a ValueError in the test that expects a nan return. It passes if I run the exact same code in the REPL, but not in the tests (it also fails locally). I can't make much of the traceback.

Details: ValueError: Scalar check failed (npy_float64)
test_math.py::TestBetaIncGrad::test_stan_grad_combined FAILED            [100%]
tests/scalar/test_math.py:54 (TestBetaIncGrad.test_stan_grad_combined)
self = <aesara.compile.function.types.Function object at 0x7f411e2dc400>
args = (1.0, 1.0, 1.0), kwargs = {}
restore_defaults = <function Function.__call__.<locals>.restore_defaults at 0x7f411e2aa820>
profile = None, t0 = 1623257249.2276874, output_subset = None, i = 3, arg = 1.0
s = <array(1.)>, c = <array(1.)>

    def __call__(self, *args, **kwargs):
        ...  # [body of aesara.compile.function.types.Function.__call__ elided]
        # Do the actual work
        t0_fn = time.time()
        try:
            outputs = (
>               self.fn()
                if output_subset is None
                else self.fn(output_subset=output_subset)
            )
E           ValueError: Scalar check failed (npy_float64)

../../aesara/compile/function/types.py:976: ValueError

During handling of the above exception, another exception occurred:

self = <tests.scalar.test_math.TestBetaIncGrad object at 0x7f413a7f8d00>

    def test_stan_grad_combined(self):
        a, b, z = aet.scalars("a", "b", "z")
        betainc_out = betainc(a, b, z)
        betainc_grad = aet.grad(betainc_out, [a, b], null_gradients="return")
        f_grad = function([a, b, z], betainc_grad)
    
        for test_a, test_b, test_z, expected_dda, expected_ddb in (
            (1.0, 1.0, 1.0, 0, np.nan),
            (1.0, 1.0, 0.4, -0.36651629, 0.30649537),
        ):
            assert_allclose(
>               f_grad(test_a, test_b, test_z), [expected_dda, expected_ddb]
            )

test_math.py:66: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../aesara/compile/function/types.py:989: in __call__
    raise_with_op(
../../aesara/link/utils.py:522: in raise_with_op
    raise exc_value.with_traceback(exc_trace)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

E           ValueError: Scalar check failed (npy_float64)
E           Apply node that caused the error: mul(second.0, betainc_ddb.0)
E           Toposort index: 9
E           Inputs types: [Scalar(float64), Scalar(float64)]
E           Inputs shapes: [(), 'No shapes']
E           Inputs strides: [(), 'No strides']
E           Inputs values: [1.0, nan]
E           Outputs clients: [[TensorFromScalar(mul.0)]]
E           
E           Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
E             File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1441, in <listcomp>
E               rval = [access_grad_cache(elem) for elem in wrt]
E             File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1394, in access_grad_cache
E               term = access_term_cache(node)[idx]
E             File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1059, in access_term_cache
E               output_grads = [access_grad_cache(var) for var in node.outputs]
E             File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1059, in <listcomp>
E               output_grads = [access_grad_cache(var) for var in node.outputs]
E             File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1394, in access_grad_cache
E               term = access_term_cache(node)[idx]
E             File "/home/ricardo/Documents/Projects/aesara/aesara/gradient.py", line 1221, in access_term_cache
E               input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
E             File "/home/ricardo/Documents/Projects/aesara/aesara/scalar/basic.py", line 1138, in L_op
E               return self.grad(inputs, output_gradients)
E             File "/home/ricardo/Documents/Projects/aesara/aesara/scalar/math.py", line 1094, in grad
E               gz * betainc_ddb_scalar(a, b, x),
E           
E           HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

../../aesara/compile/function/types.py:976: ValueError

@brandonwillard
Member

brandonwillard commented Jun 19, 2021

I am getting a ValueError in the test that expects a nan return. It passes if I run the exact same code in the REPL, but not in the tests (it also fails locally). I can't make much of the traceback.

That nan isn't a valid value for that scalar input in the C implementation of aesara.scalar.basic.Scalar. For instance, if you run the test in Python mode (e.g. config.change_flags(cxx="")), it should pass.
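For illustration, a sketch of that check; the `betainc` import path is an assumption (as above), and compiling inside `config.change_flags(cxx="")` forces the Python implementation, which tolerates the nan:

```python
import aesara.tensor as aet
from aesara import config, function
from aesara.tensor.math import betainc  # assumed import path, as above

# Compile under the Python linker (no C code), as suggested above, so the nan
# gradient value is not rejected by the C-level scalar check.
with config.change_flags(cxx=""):
    a, b, z = aet.scalars("a", "b", "z")
    grads = aet.grad(betainc(a, b, z), [a, b], null_gradients="return")
    f_grad = function([a, b, z], grads)
    print(f_grad(1.0, 1.0, 1.0))  # expected [0, nan] per the test case
```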

@ricardoV94
Contributor Author

That nan isn't a valid value for that scalar input in the C implementation of aesara.scalar.basic.Scalar. For instance, if you run the test in Python mode (e.g. config.change_flags(cxx="")), it should pass.

You are right. Is this something that should be worked around, and if so, how?

@ricardoV94 ricardoV94 force-pushed the betainc branch 3 times, most recently from 532a8b9 to cdb12ca on June 21, 2021 09:04
@brandonwillard brandonwillard force-pushed the betainc branch 2 times, most recently from 1ece906 to 41be515 on July 1, 2021 21:53
@brandonwillard
Member

We can perform the test with the nan input using the Elemwise version of betainc, because TensorTypes support nans. I just added this change and rebased.
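To make the distinction concrete (a sketch; the variable constructors below are standard Aesara, but their use here is only illustrative): the raw ScalarOp operates on `aesara.scalar` ScalarType variables, whose C storage rejects nan, while `aet.scalars`/`aet.dscalar` create 0-d TensorType variables that go through Elemwise and can carry nan.

```python
import aesara.scalar as aes
import aesara.tensor as aet

# ScalarType variable: fed straight to the ScalarOp; its C implementation
# rejects nan values (the "Scalar check failed" error above).
a_scalar = aes.float64("a")

# 0-d TensorType variable: goes through the Elemwise wrapper, and TensorTypes
# support nans, which is why the rebased test uses these instead.
a_tensor = aet.dscalar("a")
```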

@codecov

codecov bot commented Jul 2, 2021

Codecov Report

Merging #464 (5388d61) into main (b5313f1) will increase coverage by 0.05%.
The diff coverage is 99.09%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #464      +/-   ##
==========================================
+ Coverage   76.66%   76.71%   +0.05%     
==========================================
  Files         148      148              
  Lines       46400    46510     +110     
  Branches    10202    10213      +11     
==========================================
+ Hits        35573    35682     +109     
  Misses       8219     8219              
- Partials     2608     2609       +1     
Impacted Files Coverage Δ
aesara/scalar/math.py 83.91% <99.07%> (+3.52%) ⬆️
aesara/tensor/inplace.py 100.00% <100.00%> (ø)
aesara/tensor/math.py 88.75% <100.00%> (+0.01%) ⬆️

@ricardoV94
Contributor Author

ricardoV94 commented Jul 3, 2021

All tests are passing now and coverage looks good.

The grad was very sensitive to dtypes when using the Scalar Op (not only for nans), but it's fine when using the tensor Elemwise version.

Calling the derivative scalar Ops directly is also fine, but it seems more reasonable to test the derivatives via the grad, as this is how they will be used.

I would open an issue to test a pure Aesara implementation of the derivatives and merge this for the time being (if the code looks good).

Comment on lines 1082 to 1103
class BetaIncDda(TernaryScalarOp):
    """
    Gradient of the regularized incomplete beta function wrt to the first argument (a)
    """

    def impl(self, a, b, x):
        return _betainc_derivative(a, b, x, wrtp=True)


betainc_dda_scalar = BetaIncDda(upgrade_to_float_no_complex, name="betainc_dda")


class BetaIncDdb(TernaryScalarOp):
    """
    Gradient of the regularized incomplete beta function wrt to the second argument (b)
    """

    def impl(self, a, b, x):
        return _betainc_derivative(a, b, x, wrtp=False)


betainc_ddb_scalar = BetaIncDdb(upgrade_to_float_no_complex, name="betainc_ddb")
Member


Before we merge this, let's combine these into a single Op with a boolean wrtp attribute. The _betainc_derivative function can then become the entire impl method.
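A rough sketch of what that refactor could look like (not the code actually pushed in this PR; it reuses the names from the snippet above, and the combined class name and constructor signature are assumptions):

```python
class BetaIncDer(TernaryScalarOp):
    """
    Gradient of the regularized incomplete beta function wrt either the first
    argument (a, when wrtp=True) or the second argument (b, when wrtp=False).
    """

    def __init__(self, wrtp, *args, **kwargs):
        self.wrtp = wrtp
        super().__init__(*args, **kwargs)

    def impl(self, a, b, x):
        # The body of _betainc_derivative would move in here, branching on self.wrtp
        return _betainc_derivative(a, b, x, wrtp=self.wrtp)


betainc_dda_scalar = BetaIncDer(True, upgrade_to_float_no_complex, name="betainc_dda")
betainc_ddb_scalar = BetaIncDer(False, upgrade_to_float_no_complex, name="betainc_ddb")
```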

Contributor Author

@ricardoV94 ricardoV94 Jul 5, 2021


Done, is my latest push what you had in mind?

Member


If you refactor _betainc_derivative to be BetaIncDdb.impl, then yes.

Contributor Author


Got it

@ricardoV94 ricardoV94 force-pushed the betainc branch 4 times, most recently from 1e2d674 to cca785f on July 5, 2021 07:47
@brandonwillard brandonwillard merged commit 2b78c67 into aesara-devs:main Jul 6, 2021
@ricardoV94 ricardoV94 deleted the betainc branch July 8, 2021 08:10