How to pass user_data argument of f to the solver #3

Closed
papamarkou opened this issue Jun 17, 2013 · 7 comments

@papamarkou
Contributor

Hi Tom, perhaps I should not report this as an issue, since it is more of a question about how to use one of your package's features. I want to pass an extra argument to f, which you call user_data in examples/cvode_Roberts_dns.jl (of type Ptr{Void}). I will use the CVODE solver, and I don't know how this user_data argument of f is ultimately passed to CVODE. Does the package support this? As an example, I simulated the ODE system below via Sundials.ode, circumventing the problem by wrapping f2 in f. In this example, the parameters p = [a, b, c] play the role of user_data. Any hints on whether they can be passed to CVODE through your package would be valuable. Here is the simple example (the Fitzhugh-Nagumo system of two ODEs):

# Simulate data from the Fitzhugh-Nagumo ODEs
# d/dt(V) = c*(R+V-V^3/3)
# d/dt(R) = -(V-a+R*b)/c

using Sundials

function f2(t, y, ydot, p)    
    ydot[1] = p[3]*(y[2]+y[1]-y[1]^3/3)
    ydot[2] = -(y[1]-p[1]+y[2]*p[2])/p[3]
end

# Time t
t = [i for i in 0.:10]

# Initial conditions y0 = [V, R]
y0 = [-1., 1.];

# Parameters p = [a, b, c]
p = [0.2, 0.2, 3];

f(t, y, ydot) = f2(t, y, ydot, p)

data = Sundials.ode(f, y0, t)
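For anyone who wants to try the wrapper idea without Sundials installed, here is a minimal sketch in Julia 1.x syntax. The `rk4` function is a hypothetical stand-in for `Sundials.ode` (a hand-written fixed-step Runge-Kutta integrator, not CVODE); it only assumes the same in-place `f(t, y, ydot)` signature, which is exactly what the wrapper provides.

```julia
# Fitzhugh-Nagumo right-hand side with explicit parameters p = [a, b, c]
function f2(t, y, ydot, p)
    ydot[1] = p[3] * (y[2] + y[1] - y[1]^3 / 3)
    ydot[2] = -(y[1] - p[1] + y[2] * p[2]) / p[3]
    return nothing
end

# Hypothetical stand-in solver: classic RK4 with `nsub` substeps between
# consecutive output times. It only requires an in-place f(t, y, ydot).
function rk4(f, y0, t; nsub = 100)
    y = copy(y0)
    out = zeros(length(t), length(y0))
    out[1, :] = y
    k1, k2, k3, k4 = similar(y), similar(y), similar(y), similar(y)
    for i in 2:length(t)
        h = (t[i] - t[i-1]) / nsub
        tc = t[i-1]
        for _ in 1:nsub
            f(tc, y, k1)
            f(tc + h / 2, y .+ (h / 2) .* k1, k2)
            f(tc + h / 2, y .+ (h / 2) .* k2, k3)
            f(tc + h, y .+ h .* k3, k4)
            y .+= (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
            tc += h
        end
        out[i, :] = y
    end
    return out
end

y0 = [-1.0, 1.0]
p = [0.2, 0.2, 3.0]
t = collect(0.0:10.0)

# The same three-argument wrapper as above: it forwards p to f2
f(t, y, ydot) = f2(t, y, ydot, p)

data = rk4(f, y0, t)   # one row per output time, columns [V, R]
```

Any solver that accepts a three-argument in-place function can be driven this way, because the parameters never pass through the solver at all.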

@tshort
Contributor

tshort commented Jun 17, 2013

Hi Theo, what you have is the most concise way to do that right now. To eliminate the wrapper function, Sundials.ode in src/Sundials.jl could be extended to accept an additional user argument.
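One possible shape for such an extension, sketched under assumptions: a hypothetical `ode_with_params` that takes the extra argument `p` and builds the wrapper closure itself, so callers no longer write it. Since the real `Sundials.ode` is not used here, the solver is passed in as the argument `solve`, and a trivial `mock_solve` (which only evaluates the right-hand side once) stands in for it.

```julia
# Hypothetical extension: accept user parameters p and forward them to f.
# `solve` stands in for Sundials.ode and must accept an f(t, y, ydot).
function ode_with_params(solve, f, y0, t, p)
    g = (tt, y, ydot) -> f(tt, y, ydot, p)   # closure captures p
    return solve(g, y0, t)
end

# Mock solver for illustration only: evaluates the RHS once at t[1]
function mock_solve(g, y0, t)
    ydot = similar(y0)
    g(t[1], y0, ydot)
    return ydot
end

# User RHS with the four-argument signature (parameters last)
function fhn!(t, y, ydot, p)
    ydot[1] = p[3] * (y[2] + y[1] - y[1]^3 / 3)
    ydot[2] = -(y[1] - p[1] + y[2] * p[2]) / p[3]
    return nothing
end

rhs0 = ode_with_params(mock_solve, fhn!, [-1.0, 1.0], [0.0, 1.0], [0.2, 0.2, 3.0])
```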

I'm going to leave this issue open as a reminder, but I won't have time to look at it until the end of the summer.

@papamarkou
Contributor Author

Thanks Tom, that's true: the package already works, and I will use it as it is. Any changes you make in the future are welcome.

@papamarkou
Contributor Author

Tom, I solved it: I eliminated the wrapper function! Here is the code for simulating the Fitzhugh-Nagumo ODEs with CVODE. The three parameters of the ODE system are passed as a vector through CVodeSetUserData(), and the matching fourth argument type (Vector{Float64}) is declared in cfunction():

# Simulate data from the Fitzhugh-Nagumo ODEs
# d/dt(V) = c*(R+V-V^3/3)
# d/dt(R) = -(V-a+R*b)/c

using Sundials

# ODE equations
function f(t, y, ydot, p) 
    y = Sundials.asarray(y)
    ydot = Sundials.asarray(ydot)

    ydot[1] = p[3]*(y[2]+y[1]-y[1]^3/3)
    ydot[2] = -(y[1]-p[1]+y[2]*p[2])/p[3]

    return int32(0)
end

# Initial condition y0 = [V, R]
y0 = [-10., 1.];
nOdes = length(y0);

# Parameters p = [a, b, c]
p = [0.2, 0.2, 3];

# Time t
timeStart = 0; timeEnd = 25; timeStep=1.; 
t = [i for i in timeStart:timeStep:timeEnd];

# Tolerance
relTol = 1e-4;
absTol = [1e-8, 1e-8]; # one absolute tolerance per equation (V, R)

# Set up CVODE solver
cvode_mem = Sundials.CVodeCreate(Sundials.CV_BDF, Sundials.CV_NEWTON);
flag = Sundials.CVodeInit(cvode_mem, cfunction(f, Int32, (Sundials.realtype, Sundials.N_Vector, Sundials.N_Vector, Vector{Float64})), t[1], Sundials.nvector(y0));
flag = Sundials.CVodeSetUserData(cvode_mem, p);
flag = Sundials.CVodeSVtolerances(cvode_mem, relTol, absTol);
flag = Sundials.CVDense(cvode_mem, nOdes);

# Solve ODE system numerically
y = Array(Float64, length(t)-1, length(y0));

for i in 2:length(t)
  flag = Sundials.CVode(cvode_mem, t[i], y0, [t[1]], Sundials.CV_NORMAL)

  if flag != Sundials.CV_SUCCESS
    println("SUNDIALS_ERROR: CVODE failed with flag = ", flag)
    break
  end

  y[i-1, :] = y0
  println("T = ", i-1, ", Y = ", y0)
end

To use your ode() function directly, one would probably have to pass a cell array containing both the user-defined function and its parameters as user_data. I have posted the code above as a starting point for when you make changes to the package (I would be happy to contribute).
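The mechanism underneath this approach is CVODE's opaque `void *user_data` pointer: `CVodeSetUserData` stores it, and the solver passes it back unchanged as the last argument of every right-hand-side call. A minimal sketch of that round trip in Julia 1.x, assuming Base's `@cfunction`, `pointer_from_objref`, `unsafe_pointer_to_objref`, and `GC.@preserve`. CVODE itself is not called here: the script invokes the C function pointer directly with `ccall`, raw `Ptr{Cdouble}` arrays stand in for Sundials' `N_Vector`, and `FHNParams` is a hypothetical name.

```julia
# Parameters live in a mutable object so a stable pointer to it exists
mutable struct FHNParams
    a::Float64
    b::Float64
    c::Float64
end

# C-callable RHS mirroring CVODE's callback shape: (t, y, ydot, user_data) -> status
function fhn_c(t::Cdouble, y::Ptr{Cdouble}, ydot::Ptr{Cdouble}, user_data::Ptr{Cvoid})::Cint
    p = unsafe_pointer_to_objref(user_data)::FHNParams   # recover the Julia object
    V = unsafe_load(y, 1)
    R = unsafe_load(y, 2)
    unsafe_store!(ydot, p.c * (R + V - V^3 / 3), 1)
    unsafe_store!(ydot, -(V - p.a + R * p.b) / p.c, 2)
    return 0   # 0 signals success to the solver
end

const fhn_c_ptr = @cfunction(fhn_c, Cint, (Cdouble, Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cvoid}))

params = FHNParams(0.2, 0.2, 3.0)
y = [-1.0, 1.0]
ydot = zeros(2)

# Keep params, y, and ydot rooted while raw pointers to them are in use
status = GC.@preserve params y ydot begin
    ud = pointer_from_objref(params)   # the opaque user_data pointer
    ccall(fhn_c_ptr, Cint, (Cdouble, Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cvoid}),
          0.0, y, ydot, ud)
end
```

The callback never sees a typed Julia argument for the parameters, only a `Ptr{Cvoid}`; the type assertion after `unsafe_pointer_to_objref` is what restores the concrete type.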

@papamarkou
Contributor Author

P.S. Here is the generic solution, based on a cell array that passes both the function and its parameters to odefun():

# Simulate data from the Fitzhugh-Nagumo ODEs
# d/dt(V) = c*(R+V-V^3/3)
# d/dt(R) = -(V-a+R*b)/c

using Sundials

# ODE equations
function f(t, y, ydot, p)
  ydot[1] = p[3]*(y[2]+y[1]-y[1]^3/3)
  ydot[2] = -(y[1]-p[1]+y[2]*p[2])/p[3]  
end

# Generic function corresponding to odefun() in src/Sundials.jl
# user_data[1] = f and user_data[2] = p
function odeFun(t, y, ydot, user_data) 
  y = Sundials.asarray(y)
  ydot = Sundials.asarray(ydot)

  user_data[1](t, y, ydot, user_data[2])

  return int32(0)
end

# Initial condition y0 = [V, R]
y0 = [-10., 1.];
nOdes = length(y0);

# Parameters p = [a, b, c]
p = [0.2, 0.2, 3];

# Time t
timeStart = 0; timeEnd = 25; timeStep=1.; 
t = [i for i in timeStart:timeStep:timeEnd];

# Tolerance
relTol = 1e-4;
absTol = [1e-8, 1e-8]; # one absolute tolerance per equation (V, R)

# Set up CVODE solver
cvode_mem = Sundials.CVodeCreate(Sundials.CV_BDF, Sundials.CV_NEWTON);
flag = Sundials.CVodeInit(cvode_mem, cfunction(odeFun, Int32, (Sundials.realtype, Sundials.N_Vector, Sundials.N_Vector, Array{Any, 1})), t[1], Sundials.nvector(y0));
# {f, p} is pre-0.4 Julia cell-array syntax (Any[f, p] in later versions)
flag = Sundials.CVodeSetUserData(cvode_mem, {f, p});
flag = Sundials.CVodeSVtolerances(cvode_mem, relTol, absTol);
flag = Sundials.CVDense(cvode_mem, nOdes);

# Solve ODE system numerically
y = Array(Float64, length(t)-1, length(y0));

for i in 2:length(t)
  flag = Sundials.CVode(cvode_mem, t[i], y0, [t[1]], Sundials.CV_NORMAL)

  if flag != Sundials.CV_SUCCESS
    println("SUNDIALS_ERROR: CVODE failed with flag = ", flag)
    break
  end

  y[i-1, :] = y0
  println("T = ", i-1, ", Y = ", y0)
end
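The same dispatch idea carries over to current Julia. A sketch in Julia 1.x, assuming a `Tuple` in place of the `Any` cell array (the `{f, p}` literal no longer parses): a tuple keeps its element types concrete, so the compiler can infer the call to the stored function. The names `fhn!` and `generic_rhs!` are hypothetical; the dispatcher mirrors `odeFun` above but works on plain Julia vectors rather than Sundials `N_Vector`s.

```julia
# User-defined RHS with parameters p = [a, b, c]
function fhn!(t, y, ydot, p)
    ydot[1] = p[3] * (y[2] + y[1] - y[1]^3 / 3)
    ydot[2] = -(y[1] - p[1] + y[2] * p[2]) / p[3]
    return nothing
end

# Generic dispatcher: user_data = (f, p), mirroring odeFun without N_Vector conversion
function generic_rhs!(t, y, ydot, user_data)
    f, p = user_data          # destructure the (function, parameters) pair
    f(t, y, ydot, p)
    return Int32(0)           # CVODE-style success flag
end

user_data = (fhn!, [0.2, 0.2, 3.0])
ydot = zeros(2)
flag = generic_rhs!(0.0, [-1.0, 1.0], ydot, user_data)
```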

@gasagna

gasagna commented Sep 1, 2014

Nice solution to the problem. Is this planned to be included in the package by default?

Thanks

Davide

@dextorious

I actually find the wrapper function more elegant, in the sense that it lets you keep the high-level interface and write more concise, readable code. While the ideal solution would be to extend the high-level interface itself, I wonder whether the wrapper approach has real performance penalties. In other words, is it worth applying the proper solution manually to each system I'm solving, or is the wrapper approach fine for now?
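One cost worth checking is less the extra call than the fact that a wrapper such as `f(t, y, ydot) = f2(t, y, ydot, p)` reads `p` as a non-constant global, which Julia's performance tips advise against because its type cannot be inferred. A rough sketch (Julia 1.x) comparing that with a wrapper that captures `p` through `let`; `drive` is a hypothetical helper that calls the wrapper repeatedly, as a solver would. `@elapsed` timings depend on the machine and Julia version, so no numbers are claimed here.

```julia
function f2(t, y, ydot, p)
    ydot[1] = p[3] * (y[2] + y[1] - y[1]^3 / 3)
    ydot[2] = -(y[1] - p[1] + y[2] * p[2]) / p[3]
    return nothing
end

p = [0.2, 0.2, 3.0]                         # non-constant global

f_global(t, y, ydot) = f2(t, y, ydot, p)    # looks up global p on every call
f_captured = let p = p                      # p captured as a local binding
    (t, y, ydot) -> f2(t, y, ydot, p)
end

# Call a wrapper n times and return the last derivative
function drive(f, n)
    y = [-1.0, 1.0]
    ydot = zeros(2)
    for _ in 1:n
        f(0.0, y, ydot)
    end
    return ydot
end

drive(f_global, 1); drive(f_captured, 1)    # compile both before timing
t_global = @elapsed drive(f_global, 10^6)
t_captured = @elapsed drive(f_captured, 10^6)
println("global p: ", t_global, " s, captured p: ", t_captured, " s")
```

Declaring the global `const p = ...` is the other common remedy; both wrappers compute identical derivatives either way.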

@ChrisRackauckas
Member

This has pretty much been made obsolete by fast closures in Julia v0.5:

g = (t,u,du) -> f(t,u,du,p)

or by using ParameterizedFunctions.jl (or wrapper functions, or user_data).
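Two of these options fit in a few lines of Julia 1.x. The first builds the closure from the comment above inside a function, so `p` is genuinely captured rather than read as a global. The second is a callable struct (functor) that carries the parameters as a field; it is a general Julia idiom shown as an alternative, not the ParameterizedFunctions.jl API. `bind_params` and `FHN` are hypothetical names.

```julia
# Fitzhugh-Nagumo RHS with an explicit parameter argument
function f(t, u, du, p)
    du[1] = p[3] * (u[2] + u[1] - u[1]^3 / 3)
    du[2] = -(u[1] - p[1] + u[2] * p[2]) / p[3]
    return nothing
end

# 1. Closure: returns a three-argument function that captures p
bind_params(f, p) = (t, u, du) -> f(t, u, du, p)

# 2. Callable struct: the parameters are stored as a field of the object
struct FHN
    p::Vector{Float64}
end
(m::FHN)(t, u, du) = f(t, u, du, m.p)

p = [0.2, 0.2, 3.0]
g = bind_params(f, p)
h = FHN(p)

u0 = [-1.0, 1.0]
du_g = zeros(2); g(0.0, u0, du_g)
du_h = zeros(2); h(0.0, u0, du_h)
```

Both `g` and `h` have the three-argument signature a solver expects; the functor additionally makes the parameters inspectable as `h.p`.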


5 participants