DifferentialEquation Op refactor (#3634)
* addition of test for equality checking of ODE Ops
(not yet implemented)

* WIP: refactoring the DifferentialEquation Op
+ full support for test_values
+ explicit input/output types
+ 2D return shape
+ optional return of sensitivities
+ gradient without helper Op

* fully replace DifferentialEquation Op with the refactored implementation

* align tests with refactored API
+ whitespace & condensed formatting
+ test for equality of identical Ops

* use tt.stack as suggested by DeprecationWarning

* always cast y0 and theta to floatX

* allow some tests to fail on float32
(due to downcast exception)

* don't use f-strings to maintain 3.5 support

* link ODE refactor PR

* renamed ODE notebooks
+ add notebooks to examples index

* use (custom) errors instead of asserts

* move ShapeError to exceptions.py
michaelosthege authored and ColCarroll committed Nov 3, 2019
1 parent cc55279 commit 225ae82
Showing 11 changed files with 979 additions and 814 deletions.
2 changes: 1 addition & 1 deletion RELEASE-NOTES.md
@@ -4,7 +4,7 @@

### New features
- Implemented robust u turn check in NUTS (similar to stan-dev/stan#2800). See PR [#3605]
- Add capabilities to do inference on parameters in a differential equation with `DifferentialEquation`. See [#3590](https://github.com/pymc-devs/pymc3/pull/3590).
- Add capabilities to do inference on parameters in a differential equation with `DifferentialEquation`. See [#3590](https://github.com/pymc-devs/pymc3/pull/3590) and [#3634](https://github.com/pymc-devs/pymc3/pull/3634).
- Distinguish between `Data` and `Deterministic` variables when graphing models with graphviz. PR [#3491](https://github.com/pymc-devs/pymc3/pull/3491).
- Sequential Monte Carlo - Approximate Bayesian Computation step method is now available. The implementation is in an experimental stage and will be further improved.
- Added `Matern12` covariance function for Gaussian processes. This is the Matern kernel with nu=1/2.
410 changes: 410 additions & 0 deletions docs/source/notebooks/ODE_API_introduction.ipynb

Large diffs are not rendered by default.

570 changes: 0 additions & 570 deletions docs/source/notebooks/ODE_API_parameter_estimation.ipynb

This file was deleted.

317 changes: 317 additions & 0 deletions docs/source/notebooks/ODE_API_shapes_and_benchmarking.ipynb

Large diffs are not rendered by default.

docs/source/notebooks/ODE_with_manual_gradients.ipynb
@@ -20,7 +20,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Bayesian inference in non-linear ODEs using PyMC3\n",
"# Lotka-Volterra with manual gradients\n",
"\n",
"by [Sanmitra Ghosh](https://www.mrc-bsu.cam.ac.uk/people/in-alphabetical-order/a-to-g/sanmitra-ghosh/)"
]
5 changes: 3 additions & 2 deletions docs/source/notebooks/table_of_contents_examples.js
@@ -54,6 +54,7 @@ Gallery.contents = {
"normalizing_flows_overview": "Variational Inference",
"gaussian-mixture-model-advi": "Variational Inference",
"GLM-hierarchical-advi-minibatch": "Variational Inference",
"ODE_parameter_estimation": "Inference in ODE models",
"ODE_API_parameter_estimation": "Inference in ODE models"
"ODE_with_manual_gradients": "Inference in ODE models",
"ODE_API_introduction": "Inference in ODE models",
"ODE_API_shapes_and_benchmarking": "Inference in ODE models"
}
10 changes: 10 additions & 0 deletions pymc3/exceptions.py
@@ -3,6 +3,7 @@
"IncorrectArgumentsError",
"TraceDirectoryError",
"ImputationWarning",
"ShapeError"
]


@@ -24,3 +25,12 @@ class ImputationWarning(UserWarning):
"""Warning that there are missing values that will be imputed."""

pass


class ShapeError(Exception):
"""Error that the shape of a variable is incorrect."""
def __init__(self, message, actual=None, expected=None):
if expected and actual:
super().__init__('{} (actual {} != expected {})'.format(message, actual, expected))
else:
super().__init__(message)
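
For reference, a quick sketch of how the new exception renders when raised with both shape arguments — the message text mirrors the checks added in `pymc3/ode/ode.py` below; the shape values here are illustrative:

```python
from pymc3.exceptions import ShapeError

try:
    raise ShapeError('Simulated states have the wrong shape.', actual=(9, 2), expected=(9, 1))
except ShapeError as e:
    print(e)
# Simulated states have the wrong shape. (actual (9, 2) != expected (9, 1))
```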
212 changes: 120 additions & 92 deletions pymc3/ode/ode.py
@@ -1,8 +1,12 @@
import logging
import numpy as np
import scipy
import theano
import theano.tensor as tt
from ..ode.utils import augment_system, ODEGradop
from ..ode.utils import augment_system
from ..exceptions import ShapeError

_log = logging.getLogger('pymc3')


class DifferentialEquation(theano.Op):
@@ -17,16 +21,16 @@ class DifferentialEquation(theano.Op):
func : callable
Function specifying the differential equation
t0 : float
Time corresponding to the initial condition
times : array
Array of times at which to evaluate the solution of the differential equation.
n_states : int
Dimension of the differential equation. For scalar differential equations, n_states=1.
For vector valued differential equations, n_states = number of differential equations in the system.
n_odeparams : int
n_theta : int
Number of parameters in the differential equation.
t0 : float
Time corresponding to the initial condition
.. code-block:: python
def odefunc(y, t, p):
@@ -35,45 +39,49 @@ def odefunc(y, t, p):
times = np.arange(0.5, 5, 0.5)
ode_model = DifferentialEquation(func=odefunc, t0=0, times=times, n_states=1, n_odeparams=1)
ode_model = DifferentialEquation(func=odefunc, times=times, n_states=1, n_theta=1, t0=0)
"""

__props__ = ("func", "t0", "times", "n_states", "n_odeparams")

def __init__(self, func, times, n_states, n_odeparams, t0=0):
_itypes = [
tt.TensorType(theano.config.floatX, (False,)), # y0 as 1D floatX vector
tt.TensorType(theano.config.floatX, (False,)) # theta as 1D floatX vector
]
_otypes = [
tt.TensorType(theano.config.floatX, (False, False)), # model states as floatX of shape (T, S)
tt.TensorType(theano.config.floatX, (False, False, False)), # sensitivities as floatX of shape (T, S, len(y0) + len(theta))
]
__props__ = ("func", "times", "n_states", "n_theta", "t0")

def __init__(self, func, times, *, n_states, n_theta, t0=0):
if not callable(func):
raise ValueError("Argument func must be callable.")
if n_states < 1:
raise ValueError("Argument n_states must be at least 1.")
if n_odeparams <= 0:
raise ValueError("Argument n_odeparams must be positive.")
if n_theta <= 0:
raise ValueError("Argument n_theta must be positive.")

# Public
self.func = func
self.t0 = t0
self.times = tuple(times)
self.n_times = len(times)
self.n_states = n_states
self.n_odeparams = n_odeparams
self.n_theta = n_theta
self.n_p = n_states + n_theta

# Private
self._n = n_states
self._m = n_odeparams + n_states

self._augmented_times = np.insert(times, 0, t0)
self._augmented_func = augment_system(func, self._n, self._m)
self._augmented_func = augment_system(func, self.n_states, self.n_p)
self._sens_ic = self._make_sens_ic()

self._cached_y = None
self._cached_sens = None
self._cached_parameters = None

self._grad_op = ODEGradop(self._numpy_vsp)

# Cache symbolic sensitivities by the hash of inputs
self._apply_nodes = {}
self._output_sensitivities = {}

def _make_sens_ic(self):
"""
The sensitivity matrix will always have consistent form.
If the first n_odeparams entries of the parameters vector in the simulate call
correspond to ode paramaters, then the first n_odeparams columns in
If the first n_theta entries of the parameters vector in the simulate call
correspond to ODE parameters, then the first n_theta columns in
the sensitivity matrix will be 0
If the last n_states entries of the parameters vector in the simulate call
@@ -83,7 +91,7 @@ def _make_sens_ic(self):
"""

# Initialize the sensitivity matrix to be 0 everywhere
sens_matrix = np.zeros((self._n, self._m))
sens_matrix = np.zeros((self.n_states, self.n_p))

# Slip in the identity matrix in the appropriate place
sens_matrix[:, -self.n_states :] = np.eye(self.n_states)
@@ -95,89 +103,109 @@ def _make_sens_ic(self):
return dydp

def _system(self, Y, t, p):
"""This is the function that will be passed to odeint. Solves both ODE and sensitivities
"""This is the function that will be passed to odeint. Solves both ODE and sensitivities.
"""

dydt, ddt_dydp = self._augmented_func(Y[: self._n], t, p, Y[self._n :])
dydt, ddt_dydp = self._augmented_func(Y[:self.n_states], t, p, Y[self.n_states:])
derivatives = np.concatenate([dydt, ddt_dydp])
return derivatives

def _simulate(self, parameters):
# Initial condition comprised of state initial conditions and raveled
# sensitivity matrix
y0 = np.concatenate([parameters[self.n_odeparams :], self._sens_ic])
def _simulate(self, y0, theta):
# Initial condition comprised of state initial conditions and raveled sensitivity matrix
s0 = np.concatenate([y0, self._sens_ic])

# perform the integration
sol = scipy.integrate.odeint(
func=self._system, y0=y0, t=self._augmented_times, args=(parameters,)
func=self._system, y0=s0, t=self._augmented_times, args=(np.concatenate([theta, y0]),)
)
# The solution
y = sol[1:, : self.n_states]
y = sol[1:, :self.n_states]

# The sensitivities, reshaped to be a sequence of matrices
sens = sol[1:, self.n_states :].reshape(len(self.times), self._n, self._m)
sens = sol[1:, self.n_states:].reshape(self.n_times, self.n_states, self.n_p)

return y, sens

def _cached_simulate(self, parameters):
if np.array_equal(np.array(parameters), self._cached_parameters):

return self._cached_y, self._cached_sens

return self._simulate(np.array(parameters))

def _state(self, parameters):
y, sens = self._cached_simulate(np.array(parameters))
self._cached_y, self._cached_sens, self._cached_parameters = y, sens, parameters
return y.ravel()

def _numpy_vsp(self, parameters, g):
_, sens = self._cached_simulate(np.array(parameters))

# Each element of sens is an nxm sensitivity matrix
# There is one sensitivity matrix per time step, making sens a (len(times), n_states, len(parameter))
# dimensional array. Reshaping the sens array in this way is like stacking each of the elements of sens on top
# of one another.
numpy_sens = sens.reshape((self.n_states * len(self.times), len(parameters)))
# The dot product here is equivalent to np.einsum('ijk,jk', sens, g)
# if sens was not reshaped and if g had the same shape as yobs
return numpy_sens.T.dot(g)

def make_node(self, odeparams, y0):
if len(odeparams) != self.n_odeparams:
raise ValueError(
"odeparams has too many or too few parameters. Expected {a} parameter(s) but got {b}".format(
a=self.n_odeparams, b=len(odeparams)
)
)
if len(y0) != self.n_states:
raise ValueError(
"y0 has too many or too few parameters. Expected {a} parameter(s) but got {b}".format(
a=self.n_states, b=len(y0)
)
def make_node(self, y0, theta):
inputs = (y0, theta)
_log.debug('make_node for inputs {}'.format(hash(inputs)))
states = self._otypes[0]()
sens = self._otypes[1]()

# store symbolic output in dictionary such that it can be accessed in the grad method
self._output_sensitivities[hash(inputs)] = sens
return theano.Apply(self, inputs, (states, sens))

def __call__(self, y0, theta, return_sens=False, **kwargs):
# convert inputs to tensors (and check their types)
y0 = tt.cast(tt.unbroadcast(tt.as_tensor_variable(y0), 0), theano.config.floatX)
theta = tt.cast(tt.unbroadcast(tt.as_tensor_variable(theta), 0), theano.config.floatX)
inputs = [y0, theta]
for i, (input, itype) in enumerate(zip(inputs, self._itypes)):
if not input.type == itype:
raise ValueError('Input {} of type {} does not have the expected type of {}'.format(i, input.type, itype))

# use default implementation to prepare symbolic outputs (via make_node)
states, sens = super(theano.Op, self).__call__(y0, theta, **kwargs)

if theano.config.compute_test_value != 'off':
# compute test values from input test values
test_states, test_sens = self._simulate(
y0=self._get_test_value(y0),
theta=self._get_test_value(theta)
)

if np.ndim(odeparams) > 1:
odeparams = np.ravel(odeparams)
if np.ndim(y0) > 1:
y0 = np.ravel(y0)

odeparams = tt.as_tensor_variable(odeparams)
y0 = tt.as_tensor_variable(y0)
parameters = tt.concatenate([odeparams, y0])
return theano.Apply(self, [parameters], [parameters.type()])
# check types of simulation result
if not test_states.dtype == self._otypes[0].dtype:
raise TypeError('Simulated states have the wrong type')
if not test_sens.dtype == self._otypes[1].dtype:
raise TypeError('Simulated sensitivities have the wrong type')

# check shapes of simulation result
expected_states_shape = (self.n_times, self.n_states)
expected_sens_shape = (self.n_times, self.n_states, self.n_p)
if not test_states.shape == expected_states_shape:
raise ShapeError('Simulated states have the wrong shape.', test_states.shape, expected_states_shape)
if not test_sens.shape == expected_sens_shape:
raise ShapeError('Simulated sensitivities have the wrong shape.', test_sens.shape, expected_sens_shape)

# attach results as test values to the outputs
states.tag.test_value = test_states
sens.tag.test_value = test_sens

if return_sens:
return states, sens
return states

def perform(self, node, inputs_storage, output_storage):
parameters = inputs_storage[0]
out = output_storage[0]
# get the numerical solution of ODE states
out[0] = self._state(parameters)
y0, theta = inputs_storage[0], inputs_storage[1]
# simulate states and sensitivities in one forward pass
output_storage[0][0], output_storage[1][0] = self._simulate(y0, theta)

def grad(self, inputs, output_grads):
x = inputs[0]
g = output_grads[0]
# pass the VSP when asked for gradient
grad_op_apply = self._grad_op(x, g)
def infer_shape(self, node, input_shapes):
s_y0, s_theta = input_shapes
output_shapes = [(self.n_times, self.n_states), (self.n_times, self.n_states, self.n_p)]
return output_shapes

return [grad_op_apply]
def grad(self, inputs, output_grads):
_log.debug('grad w.r.t. inputs {}'.format(hash(tuple(inputs))))

# fetch symbolic sensitivity output node from cache
ihash = hash(tuple(inputs))
if ihash in self._output_sensitivities:
sens = self._output_sensitivities[ihash]
else:
_log.debug('No cached sensitivities found!')
_, sens = self.__call__(*inputs, return_sens=True)
ograds = output_grads[0]

# for each parameter, multiply sensitivities with the output gradient and sum the result
# sens is (n_times, n_states, n_p)
# ograds is (n_times, n_states)
grads = [
tt.sum(sens[:,:,p] * ograds)
for p in range(self.n_p)
]

# return separate gradient tensors for y0 and theta inputs
result = tt.stack(grads[:self.n_states]), tt.stack(grads[self.n_states:])
return result
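
To make the refactored API concrete, here is a minimal usage sketch assembled from the docstring and the new `__call__` above — the decay model, priors, and synthetic data are illustrative, not part of this commit:

```python
import numpy as np
import pymc3 as pm
from pymc3.ode import DifferentialEquation

def odefunc(y, t, p):
    # exponential decay: dy/dt = -p[0] * y[0]
    return -p[0] * y[0]

times = np.arange(0.5, 5, 0.5)
ode_model = DifferentialEquation(func=odefunc, times=times, n_states=1, n_theta=1, t0=0)

# synthetic observations (illustrative only)
yobs = 3 * np.exp(-0.5 * times) + np.random.normal(0, 0.1, size=len(times))

with pm.Model():
    lam = pm.Lognormal('lam', 0, 1)
    y0 = pm.Lognormal('y0', 0, 1)
    # inputs are now (y0, theta) instead of a single concatenated parameter
    # vector; the output is 2D with shape (n_times, n_states)
    solution = ode_model(y0=[y0], theta=[lam])
    # ode_model(..., return_sens=True) would additionally return the
    # (n_times, n_states, n_p) sensitivity tensor
    pm.Normal('obs', mu=solution[:, 0], sd=0.1, observed=yobs)
```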
31 changes: 8 additions & 23 deletions pymc3/ode/utils.py
@@ -45,41 +45,26 @@ def augment_system(ode_func, n, m):

dydp = dydp_vec.reshape((n, m))

# Stack the results of the ode_func
f_tensor = tt.stack(ode_func(t_y, t_t, t_p))
# Stack the results of the ode_func into a single tensor variable
yhat = ode_func(t_y, t_t, t_p)
if not isinstance(yhat, (list, tuple)):
yhat = (yhat,)
t_yhat = tt.stack(yhat, axis=0)

# Now compute gradients
J = tt.jacobian(f_tensor, t_y)
J = tt.jacobian(t_yhat, t_y)

Jdfdy = tt.dot(J, dydp)

grad_f = tt.jacobian(f_tensor, t_p)
grad_f = tt.jacobian(t_yhat, t_p)

# This is the time derivative of dydp
ddt_dydp = (Jdfdy + grad_f).flatten()

system = theano.function(
inputs=[t_y, t_t, t_p, dydp_vec],
outputs=[f_tensor, ddt_dydp],
outputs=[t_yhat, ddt_dydp],
on_unused_input="ignore",
)

return system


class ODEGradop(theano.Op):
def __init__(self, numpy_vsp):
self._numpy_vsp = numpy_vsp

def make_node(self, x, g):

x = theano.tensor.as_tensor_variable(x)
g = theano.tensor.as_tensor_variable(g)
node = theano.Apply(self, [x, g], [g.type()])
return node

def perform(self, node, inputs_storage, output_storage):
x = inputs_storage[0]
g = inputs_storage[1]
out = output_storage[0]
out[0] = self._numpy_vsp(x, g) # get the numerical VSP
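
With `ODEGradop` removed, the vector-Jacobian product now lives directly in `DifferentialEquation.grad`: each entry of the gradient is one sensitivity slice contracted against the output gradient. A numpy sketch of that contraction and its equivalence to the removed `_numpy_vsp` formulation — all shapes and values here are illustrative:

```python
import numpy as np

n_times, n_states, n_p = 9, 1, 2                 # n_p = n_states + n_theta
sens = np.random.rand(n_times, n_states, n_p)    # sensitivities, as returned by _simulate
ograds = np.random.rand(n_times, n_states)       # gradient w.r.t. the states output

# new grad(): one scalar per parameter, summing over times and states
grads = np.array([np.sum(sens[:, :, p] * ograds) for p in range(n_p)])

# equivalent to the removed ODEGradop/_numpy_vsp reshaped dot product ...
numpy_sens = sens.reshape((n_times * n_states, n_p))
assert np.allclose(grads, numpy_sens.T.dot(ograds.ravel()))

# ... and to a single einsum over the time and state axes
assert np.allclose(grads, np.einsum('tsp,ts->p', sens, ograds))
```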