Skip to content

Repair ode API after refactor broke it #3684

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 19, 2019
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
more detailed docstrings
  • Loading branch information
michaelosthege committed Nov 18, 2019
commit 83085461dbcbc5d3814e4f29ce7200b6f254b3ba
8 changes: 6 additions & 2 deletions pymc3/ode/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class DifferentialEquation(theano.Op):
----------

func : callable
Function specifying the differential equation
Function specifying the differential equation. Must take arguments y (n_states,), t (scalar), p (n_theta,)
times : array
Array of times at which to evaluate the solution of the differential equation.
n_states : int
Expand Down Expand Up @@ -101,11 +101,15 @@ def _make_sens_ic(self):
# We need the sensitivity matrix to be a vector (see augmented_function)
# Ravel and return
dydp = sens_matrix.ravel()

return dydp

def _system(self, Y, t, p):
"""This is the function that will be passed to odeint. Solves both ODE and sensitivities.

Args:
Y: augmented state vector (n_states + n_states + n_theta)
t: current time
p: parameter vector (y0, theta)
"""
dydt, ddt_dydp = self._augmented_func(Y[:self.n_states], t, p, Y[self.n_states:])
derivatives = np.concatenate([dydt, ddt_dydp])
Expand Down