Skip to content

Commit

Permalink
Merge pull request #4 from LukeLabrie/master
Browse files Browse the repository at this point in the history
find time at which derivative of spline reaches a given value
  • Loading branch information
Wrzlprmft authored Jun 21, 2024
2 parents 1fc965c + 59498cf commit ec343ec
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 34 deletions.
40 changes: 26 additions & 14 deletions chspy/_chspy.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def extrema_from_anchors(anchors,beginning=None,end=None,target=None):

return extrema

def solve_from_anchors(anchors,i,value,beginning=None,end=None):
def solve_from_anchors(anchors,i,value,beginning=None,end=None,solve_derivative=False):
"""
Finds the times at which a component of the Hermite interpolant for the anchors assumes a given value and the derivatives at those points (allowing to distinguish upwards and downwards threshold crossings).
Expand All @@ -293,29 +293,38 @@ def solve_from_anchors(anchors,i,value,beginning=None,end=None):
Beginning of the time interval for which positions are returned. If `None`, the time of the first anchor is used.
end : float or `None`
End of the time interval for which positions are returned. If `None`, the time of the last anchor is used.
solve_derivative : bool
Whether to find where the derivative (instead of the state) assumes a given value.
Returns
-------
positions : list of pairs of floats
Each pair consists of a time where `value` is assumed and the derivative (of `component`) at that time.
"""

q = (anchors[1].time-anchors[0].time)
retransform = lambda x: q*x+anchors[0].time
a = anchors[0].state[i]
b = anchors[0].diff[i] * q
c = anchors[1].state[i]
d = anchors[1].diff[i] * q
p0 = anchors[0].state[i]
m0 = anchors[0].diff[i] * q
p1 = anchors[1].state[i]
m1 = anchors[1].diff[i] * q

left_x = 0 if beginning is None else (beginning-anchors[0].time)/q
right_x = 1 if end is None else (end -anchors[0].time)/q

candidates = np.roots([
2*a + b - 2*c + d,
-3*a - 2*b + 3*c - d,
b,
a - value,
])
if solve_derivative:
candidates = np.roots([
3 * (2*p0 + m0 - 2*p1 + m1),
2 * (-3*p0 - 2*m0 + 3*p1 - m1),
m0 - value*q,
])
else:
candidates = np.roots([
2*p0 + m0 - 2*p1 + m1,
-3*p0 - 2*m0 + 3*p1 - m1,
m0,
p0 - value,
])

solutions = sorted(
retransform(candidate.real)
Expand Down Expand Up @@ -706,7 +715,7 @@ def extrema(self,beginning=None,end=None):

return extrema

def solve(self,i,value,beginning=None,end=None):
def solve(self,i,value,beginning=None,end=None,solve_derivative=False):
"""
Finds the times at which a component of the spline assumes a given value and the derivatives at those points (allowing to distinguish upwards and downwards threshold crossings). This will not work well if the spline is constantly at the given value for some interval.
Expand All @@ -720,6 +729,8 @@ def solve(self,i,value,beginning=None,end=None):
Beginning of the time interval for which solutions are returned. If `None`, the time of the first anchor is used.
end : float or `None`
End of the time interval for which solutions are returned. If `None`, the time of the last anchor is used.
solve_derivative : bool
Whether to find where the derivative (instead of the state) assumes a given value.
Returns
-------
Expand All @@ -740,13 +751,14 @@ def solve(self,i,value,beginning=None,end=None):
for j in range(self.last_index_before(beginning),len(self)-1):
if self[j].time>end:
break

new_sols = solve_from_anchors(
anchors = ( self[j], self[j+1] ),
i = i,
value = value,
beginning = max( beginning, self[j ].time ),
end = min( end , self[j+1].time ),
solve_derivative = solve_derivative,
)

if sols and new_sols and sols[-1][0]==new_sols[0][0]:
Expand Down
48 changes: 28 additions & 20 deletions tests/test_hermite_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from chspy._chspy import rel_dist

import symengine
import sympy
import numpy as np
from numpy.testing import assert_allclose
import unittest
Expand Down Expand Up @@ -269,26 +270,33 @@ def test_multiple_anchors(self):

class TestSolving(unittest.TestCase):
def test_random_function(self):
roots = np.sort(np.random.normal(size=5))
value = np.random.normal()
t = symengine.Symbol("t")
function = np.prod([t-root for root in roots]) + value

i = 1
spline = CubicHermiteSpline.from_func(
[10,function,10],
times_of_interest = ( min(roots)-0.01, max(roots)+0.01 ),
max_anchors = 1000,
tol = 7,
)

solutions = spline.solve(i=i,value=value)
sol_times = [ sol[0] for sol in solutions ]
assert_allclose( spline.get_state(sol_times)[:,i], value )
assert_allclose( [sol[0] for sol in solutions], roots, atol=1e-3 )
for time,diff in solutions:
true_diff = float(function.diff(t).subs({t:time}))
self.assertAlmostEqual( true_diff, diff, places=5 )
for solve_derivative in [False,True]:
roots = np.sort(np.random.normal(size=5))
value = np.random.normal()
t = symengine.Symbol("t")
function = np.prod([t-root for root in roots]) + value
if solve_derivative:
function = sympy.integrate(function,[t]) + np.random.random()

i = 1
spline = CubicHermiteSpline.from_func(
[10,function,10],
times_of_interest = ( min(roots)-0.01, max(roots)+0.01 ),
max_anchors = 1000,
tol = 7,
)

solutions = spline.solve(i=i,value=value,solve_derivative=solve_derivative)
sol_times = [ sol[0] for sol in solutions ]
assert_allclose( sol_times, roots, atol=1e-3 )
if solve_derivative:
for time,diff in solutions:
self.assertAlmostEqual( value, diff, places=5 )
else:
assert_allclose( spline.get_state(sol_times)[:,i], value )
for time,diff in solutions:
true_diff = float(function.diff(t).subs({t:time}))
self.assertAlmostEqual( true_diff, diff, places=5 )

class TimeSeriesTest(unittest.TestCase):
def test_comparison(self):
Expand Down

0 comments on commit ec343ec

Please sign in to comment.