
Commit c7a41c3

Merge pull request #3590 from Dpananos/gsoc_ode
Add Differential Equation API
2 parents 2f1d0fb + 1fae10d commit c7a41c3

9 files changed, +1405 -199 lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 ## PyMC3 3.8 (on deck)

 ### New features
-
+- Add capabilities to do inference on parameters in a differential equation with `DifferentialEquation`. See [#3590](https://github.com/pymc-devs/pymc3/pull/3590).
 - Distinguish between `Data` and `Deterministic` variables when graphing models with graphviz. PR [#3491](https://github.com/pymc-devs/pymc3/pull/3491).
 - Sequential Monte Carlo - Approximate Bayesian Computation step method is now available. The implementation is in an experimental stage and will be further improved.
 - Added `Matern12` covariance function for Gaussian processes. This is the Matern kernel with nu=1/2.

docs/source/notebooks/ODE_API_parameter_estimation.ipynb

Lines changed: 570 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/notebooks/ODE_parameter_estimation.ipynb

Lines changed: 99 additions & 197 deletions
Large diffs are not rendered by default.

docs/source/notebooks/table_of_contents_examples.js

Lines changed: 2 additions & 1 deletion
@@ -53,5 +53,6 @@ Gallery.contents = {
     "normalizing_flows_overview": "Variational Inference",
     "gaussian-mixture-model-advi": "Variational Inference",
     "GLM-hierarchical-advi-minibatch": "Variational Inference",
-    "ODE_parameter_estimation": "Inference in ODE models"
+    "ODE_parameter_estimation": "Inference in ODE models",
+    "ODE_API_parameter_estimation": "Inference in ODE models with DifferentialEquation"
 }

pymc3/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 from .math import logaddexp, logsumexp, logit, invlogit, expand_packed_triangular, probit, invprobit
 from .model import *
 from .model_graph import model_to_graphviz
+from . import ode
 from .stats import *
 from .sampling import *
 from .step_methods import *

pymc3/ode/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from . import utils
+from .ode import DifferentialEquation
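
With `from . import ode` added to `pymc3/__init__.py` and `DifferentialEquation` re-exported here, the new Op is reachable from the top-level package. A minimal import sketch (the `pm` alias and the logistic example simply mirror the docstring in `pymc3/ode/ode.py` below; nothing beyond the import path is specific to this diff):

import numpy as np
import pymc3 as pm


def odefunc(y, t, p):
    # Logistic growth, as in the DifferentialEquation docstring below
    return p[0] * y[0] * (1 - y[0])


times = np.arange(0.5, 5, 0.5)

# The Op is now exposed through the new pymc3.ode subpackage
ode_model = pm.ode.DifferentialEquation(
    func=odefunc, t0=0, times=times, n_states=1, n_odeparams=1
)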

pymc3/ode/ode.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
import numpy as np
import scipy
import theano
import theano.tensor as tt
from ..ode.utils import augment_system, ODEGradop


class DifferentialEquation(theano.Op):
    """
    Specify an ordinary differential equation

    .. math::
        \dfrac{dy}{dt} = f(y,t,p) \quad y(t_0) = y_0

    Parameters
    ----------

    func : callable
        Function specifying the differential equation
    t0 : float
        Time corresponding to the initial condition
    times : array
        Array of times at which to evaluate the solution of the differential equation.
    n_states : int
        Dimension of the differential equation. For scalar differential equations, n_states=1.
        For vector valued differential equations, n_states = number of differential equations in the system.
    n_odeparams : int
        Number of parameters in the differential equation.

    .. code-block:: python

        def odefunc(y, t, p):
            # Logistic differential equation
            return p[0] * y[0] * (1 - y[0])

        times = np.arange(0.5, 5, 0.5)

        ode_model = DifferentialEquation(func=odefunc, t0=0, times=times, n_states=1, n_odeparams=1)
    """

    __props__ = ("func", "t0", "times", "n_states", "n_odeparams")

    def __init__(self, func, times, n_states, n_odeparams, t0=0):
        if not callable(func):
            raise ValueError("Argument func must be callable.")
        if n_states < 1:
            raise ValueError("Argument n_states must be at least 1.")
        if n_odeparams <= 0:
            raise ValueError("Argument n_odeparams must be positive.")

        # Public
        self.func = func
        self.t0 = t0
        self.times = tuple(times)
        self.n_states = n_states
        self.n_odeparams = n_odeparams

        # Private
        self._n = n_states
        self._m = n_odeparams + n_states

        self._augmented_times = np.insert(times, 0, t0)
        self._augmented_func = augment_system(func, self._n, self._m)
        self._sens_ic = self._make_sens_ic()

        self._cached_y = None
        self._cached_sens = None
        self._cached_parameters = None

        self._grad_op = ODEGradop(self._numpy_vsp)

    def _make_sens_ic(self):
        """
        The sensitivity matrix will always have a consistent form.
        If the first n_odeparams entries of the parameters vector in the simulate call
        correspond to ODE parameters, then the first n_odeparams columns of
        the sensitivity matrix will be 0.

        If the last n_states entries of the parameters vector in the simulate call
        correspond to initial conditions of the system,
        then the last n_states columns of the sensitivity matrix should form
        an identity matrix.
        """

        # Initialize the sensitivity matrix to be 0 everywhere
        sens_matrix = np.zeros((self._n, self._m))

        # Slip in the identity matrix in the appropriate place
        sens_matrix[:, -self.n_states :] = np.eye(self.n_states)

        # We need the sensitivity matrix to be a vector (see augmented_function)
        # Ravel and return
        dydp = sens_matrix.ravel()

        return dydp

    def _system(self, Y, t, p):
        """This is the function that will be passed to odeint. Solves both the ODE and the sensitivities."""

        dydt, ddt_dydp = self._augmented_func(Y[: self._n], t, p, Y[self._n :])
        derivatives = np.concatenate([dydt, ddt_dydp])
        return derivatives

    def _simulate(self, parameters):
        # Initial condition comprised of state initial conditions and raveled
        # sensitivity matrix
        y0 = np.concatenate([parameters[self.n_odeparams :], self._sens_ic])

        # perform the integration
        sol = scipy.integrate.odeint(
            func=self._system, y0=y0, t=self._augmented_times, args=(parameters,)
        )
        # The solution
        y = sol[1:, : self.n_states]

        # The sensitivities, reshaped to be a sequence of matrices
        sens = sol[1:, self.n_states :].reshape(len(self.times), self._n, self._m)

        return y, sens

    def _cached_simulate(self, parameters):
        if np.array_equal(np.array(parameters), self._cached_parameters):

            return self._cached_y, self._cached_sens

        return self._simulate(np.array(parameters))

    def _state(self, parameters):
        y, sens = self._cached_simulate(np.array(parameters))
        self._cached_y, self._cached_sens, self._cached_parameters = y, sens, parameters
        return y.ravel()

    def _numpy_vsp(self, parameters, g):
        _, sens = self._cached_simulate(np.array(parameters))

        # Each element of sens is an n x m sensitivity matrix.
        # There is one sensitivity matrix per time step, making sens a
        # (len(times), n_states, len(parameters)) dimensional array. Reshaping the
        # sens array in this way is like stacking each of the elements of sens on top
        # of one another.
        numpy_sens = sens.reshape((self.n_states * len(self.times), len(parameters)))
        # The dot product here is equivalent to np.einsum('ijk,jk', sens, g)
        # if sens was not reshaped and if g had the same shape as yobs
        return numpy_sens.T.dot(g)

    def make_node(self, odeparams, y0):
        if len(odeparams) != self.n_odeparams:
            raise ValueError(
                "odeparams has too many or too few parameters. Expected {a} parameter(s) but got {b}".format(
                    a=self.n_odeparams, b=len(odeparams)
                )
            )
        if len(y0) != self.n_states:
            raise ValueError(
                "y0 has too many or too few parameters. Expected {a} parameter(s) but got {b}".format(
                    a=self.n_states, b=len(y0)
                )
            )

        if np.ndim(odeparams) > 1:
            odeparams = np.ravel(odeparams)
        if np.ndim(y0) > 1:
            y0 = np.ravel(y0)

        odeparams = tt.as_tensor_variable(odeparams)
        y0 = tt.as_tensor_variable(y0)
        parameters = tt.concatenate([odeparams, y0])
        return theano.Apply(self, [parameters], [parameters.type()])

    def perform(self, node, inputs_storage, output_storage):
        parameters = inputs_storage[0]
        out = output_storage[0]
        # get the numerical solution of the ODE states
        out[0] = self._state(parameters)

    def grad(self, inputs, output_grads):
        x = inputs[0]
        g = output_grads[0]
        # pass the VSP when asked for the gradient
        grad_op_apply = self._grad_op(x, g)

        return [grad_op_apply]
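
Because the ODE_API_parameter_estimation notebook above is not rendered, here is a hedged sketch of how the Op might be wired into a model. Two facts ground it: calling the Op dispatches to make_node(odeparams, y0), and the output is the solution at `times` flattened to a 1-D vector of length len(times) * n_states (see `_state`). The priors, likelihood, and `yobs` below are purely illustrative, not taken from the notebook:

import numpy as np
import pymc3 as pm

from pymc3.ode import DifferentialEquation


def odefunc(y, t, p):
    # Logistic growth: dy/dt = p[0] * y * (1 - y)
    return p[0] * y[0] * (1 - y[0])


times = np.arange(0.5, 5, 0.5)
ode_model = DifferentialEquation(
    func=odefunc, t0=0, times=times, n_states=1, n_odeparams=1
)

# Illustrative stand-in for measured data; it must supply
# len(times) * n_states values to line up with the flattened Op output.
yobs = np.random.lognormal(mean=np.log(0.5), sigma=0.1, size=len(times))

with pm.Model():
    alpha = pm.Lognormal("alpha", mu=0, sd=1)       # ODE parameter p[0]
    y0 = pm.Lognormal("y0", mu=np.log(0.3), sd=1)   # initial condition
    sigma = pm.HalfCauchy("sigma", 1)

    # Calling the Op builds the graph node; gradients come from ODEGradop
    # via the sensitivities computed in _simulate.
    forward = ode_model(odeparams=[alpha], y0=[y0])

    pm.Lognormal("Y_obs", mu=pm.math.log(forward), sd=sigma, observed=yobs)

    trace = pm.sample(500, tune=500)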

pymc3/ode/utils.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
import numpy as np
import theano
import theano.tensor as tt


def augment_system(ode_func, n, m):
    """
    Function to create the augmented system.

    Take a function which specifies a set of differential equations and return
    a compiled function which allows for computation of gradients of the
    differential equation's solution with respect to the parameters.

    Args:
        ode_func (function): Differential equation. Returns array-like
        n: Number of rows of the sensitivity matrix
        m: Number of columns of the sensitivity matrix

    Returns:
        system (function): Augmented system of differential equations.
    """

    # Present state of the system
    t_y = tt.vector("y", dtype=theano.config.floatX)
    t_y.tag.test_value = np.zeros((n,))
    # Parameter(s). Should be a vector to allow for generalization to multiparameter
    # systems of ODEs. Is m dimensional because it includes all ODE parameters as well as initial conditions
    t_p = tt.vector("p", dtype=theano.config.floatX)
    t_p.tag.test_value = np.zeros((m,))
    # Time. Allow for non-autonomous systems of ODEs to be analyzed
    t_t = tt.scalar("t", dtype=theano.config.floatX)
    t_t.tag.test_value = 2.459

    # Present state of the gradients:
    # Will always be 0 unless the parameter is the initial condition
    # Entry i,j is the partial of y[i] with respect to p[j]
    dydp_vec = tt.vector("dydp", dtype=theano.config.floatX)
    dydp_vec.tag.test_value = np.zeros(n * m)

    dydp = dydp_vec.reshape((n, m))

    # Stack the results of the ode_func
    f_tensor = tt.stack(ode_func(t_y, t_t, t_p))

    # Now compute gradients
    J = tt.jacobian(f_tensor, t_y)

    Jdfdy = tt.dot(J, dydp)

    grad_f = tt.jacobian(f_tensor, t_p)

    # This is the time derivative of dydp
    ddt_dydp = (Jdfdy + grad_f).flatten()

    system = theano.function(
        inputs=[t_y, t_t, t_p, dydp_vec],
        outputs=[f_tensor, ddt_dydp],
        on_unused_input="ignore",
    )

    return system


class ODEGradop(theano.Op):
    def __init__(self, numpy_vsp):
        self._numpy_vsp = numpy_vsp

    def make_node(self, x, g):

        x = theano.tensor.as_tensor_variable(x)
        g = theano.tensor.as_tensor_variable(g)
        node = theano.Apply(self, [x, g], [g.type()])
        return node

    def perform(self, node, inputs_storage, output_storage):
        x = inputs_storage[0]
        g = inputs_storage[1]
        out = output_storage[0]
        out[0] = self._numpy_vsp(x, g)  # get the numerical VSP
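
For readers checking the math: `ddt_dydp = Jdfdy + grad_f` above is the standard forward-sensitivity ODE, and `_make_sens_ic` in `ode.py` supplies its initial condition (zero columns for the ODE parameters, an identity block for the initial conditions). A sketch in LaTeX, where dydp in the code stands for the matrix of partials of y with respect to p:

\frac{d}{dt}\left(\frac{\partial y}{\partial p}\right)
    = \underbrace{\frac{\partial f}{\partial y}}_{J}\,\frac{\partial y}{\partial p}
    + \frac{\partial f}{\partial p},
\qquad
\left.\frac{\partial y}{\partial p}\right|_{t_0}
    = \bigl[\, 0_{\,n\_states \times n\_odeparams} \;\; I_{n\_states} \,\bigr]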

0 commit comments
