From 16b5f956c3a1bea0a9234b31be74cd86c471f075 Mon Sep 17 00:00:00 2001 From: Luc Grosheintz Date: Wed, 13 Nov 2024 14:01:59 +0100 Subject: [PATCH] Fix for finite differences. (#1558) The previous implementation would apply the transformation to children of the node in the AST. Hence, if the root was `sp.Derivative` it would fail. This commit adds a fix and the tests. --- python/nmodl/ode.py | 6 +++++- test/unit/ode/test_ode.py | 21 +++++++++++++++++++ test/usecases/solve/finite_difference.mod | 7 +++++-- test/usecases/solve/test_finite_difference.py | 5 +++++ 4 files changed, 36 insertions(+), 3 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index c1c907eae..f34c9d026 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -366,7 +366,11 @@ def transform_expression(expr, transform): if not expr.args: return expr - args = (transform_expression(transform(arg), transform) for arg in expr.args) + transformed_expr = transform(expr) + if transformed_expr is not expr: + return transformed_expr + + args = (transform_expression(arg, transform) for arg in expr.args) return expr.func(*args) diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 82e0358d2..8be63e02e 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 from nmodl.ode import differentiate2c, integrate2c, make_symbol +from nmodl.ode import transform_expression, discretize_derivative import pytest import sympy as sp @@ -186,3 +187,23 @@ def test_integrate2c(): assert _equivalent( integrate2c(f"x'={eq}", "dt", var_list, use_pade_approx=True), f"x = {sol}" ) + + +def test_finite_difference(): + df_dx = "(f(x + x_delta_/2) - f(x - x_delta_/2))/x_delta_" + dg_dx = "(g(x + x_delta_/2) - g(x - x_delta_/2))/x_delta_" + + test_cases = [ + ("f(x)", df_dx), + ("a*f(x)", f"a*{df_dx}"), + ("a*f(x)*g(x)", f"a*({df_dx}*g(x) + f(x)*{dg_dx})"), + ("a*f(x) + b*g(x)", f"a*{df_dx} + b*{dg_dx}"), + ] + vars = ["a", "x", "x_delta_"] + + for expr, expected in test_cases: + expr = sp.diff(sp.sympify(expr), "x") + actual = transform_expression(expr, discretize_derivative) + msg = f"'{actual}' =!= '{expected}'" + + assert _equivalent(str(actual), expected, vars=vars), msg diff --git a/test/usecases/solve/finite_difference.mod b/test/usecases/solve/finite_difference.mod index 2c0f94e86..1a1bab103 100644 --- a/test/usecases/solve/finite_difference.mod +++ b/test/usecases/solve/finite_difference.mod @@ -10,10 +10,12 @@ ASSIGNED { STATE { x + z } INITIAL { x = 42.0 + z = 21.0 a = 0.1 } @@ -22,9 +24,10 @@ BREAKPOINT { } DERIVATIVE dX { - x' = -f(x) + x' = f(x) + z' = 2.0*f(z) } FUNCTION f(x) { - f = a*x + f = -a*x } diff --git a/test/usecases/solve/test_finite_difference.py b/test/usecases/solve/test_finite_difference.py index b0c22af6b..e4b8f9170 100644 --- a/test/usecases/solve/test_finite_difference.py +++ b/test/usecases/solve/test_finite_difference.py @@ -11,6 +11,7 @@ def test_finite_difference(): s.nseg = nseg x_hoc = h.Vector().record(s(0.5)._ref_x_finite_difference) + z_hoc = h.Vector().record(s(0.5)._ref_z_finite_difference) t_hoc = h.Vector().record(h._ref_t) h.stdinit() @@ -19,12 +20,16 @@ def test_finite_difference(): h.run() x = np.array(x_hoc.as_numpy()) + z = np.array(z_hoc.as_numpy()) t = np.array(t_hoc.as_numpy()) a = h.a_finite_difference x_exact = 42.0 * np.exp(-a * t) np.testing.assert_allclose(x, x_exact, rtol=1e-4) + z_exact = 21.0 * np.exp(-2.0 * a * t) + np.testing.assert_allclose(z, z_exact, rtol=1e-4) + if __name__ == "__main__": test_finite_difference()