-
-
Notifications
You must be signed in to change notification settings - Fork 141
/
Copy pathjump_step_timing.py
128 lines (91 loc) · 2.97 KB
/
jump_step_timing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from warnings import simplefilter
simplefilter(action="ignore", category=FutureWarning)
import timeit
from functools import partial
import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
from old_pid_controller import OldPIDController
# Integration window, initial step size and initial state shared by every solve.
t0 = 0
t1 = 5
dt0 = 0.5
y0 = 1.0


def _drift_vf(t, y, args):
    """Drift vector field: -0.2 * y (``t`` and ``args`` are unused)."""
    return -0.2 * y


drift = diffrax.ODETerm(_drift_vf)
def diffusion_vf(t, y, args):
    """Diffusion vector field with a constant unit coefficient.

    Returns a 0-d array of ones in ``y``'s dtype; ``t`` and ``args`` are unused.
    """
    unit = jnp.ones(shape=(), dtype=y.dtype)
    return unit
def get_terms(key):
    """Build the SDE terms for one sample path.

    Combines the module-level ``drift`` with a ``ControlTerm`` that applies
    ``diffusion_vf`` to a ``VirtualBrownianTree`` on ``[t0, t1]`` seeded by ``key``.
    """
    brownian_motion = diffrax.VirtualBrownianTree(t0, t1, 2**-5, (), key)
    return diffrax.MultiTerm(
        drift, diffrax.ControlTerm(diffusion_vf, brownian_motion)
    )
solver = diffrax.Heun()
# 129 evenly spaced times on [t0, t1], inclusive of both ends; passed as
# `step_ts` to both controllers.
step_ts = jnp.linspace(t0, t1, 129, endpoint=True)

# One set of PID settings, shared by the old and new controllers so the
# benchmark compares like with like.
_pid_kwargs = dict(rtol=0, atol=1e-3, dtmin=2**-9, dtmax=1.0, pcoeff=0.3, icoeff=0.7)

pid_controller = diffrax.PIDController(**_pid_kwargs)
# New implementation: `step_ts` is handled by a wrapper around the PID controller.
new_controller = diffrax.JumpStepWrapper(
    pid_controller,
    step_ts=step_ts,
    rejected_step_buffer_len=0,
)
# Old implementation: `step_ts` is passed straight to the PID-style controller.
old_controller = OldPIDController(**_pid_kwargs, step_ts=step_ts)
@eqx.filter_jit
@partial(jax.vmap, in_axes=(0, None))
def solve(key, controller):
    """Solve the SDE for one PRNG key with the given step-size controller.

    Vectorised over ``key`` (axis 0) with ``controller`` shared across the
    batch, then compiled with ``eqx.filter_jit``. The solution is saved at
    every time in ``step_ts``.
    """
    save_at = diffrax.SaveAt(ts=step_ts)
    sde_terms = get_terms(key)
    return diffrax.diffeqsolve(
        sde_terms,
        solver,
        t0,
        t1,
        dt0,
        y0,
        saveat=save_at,
        stepsize_controller=controller,
    )
# One PRNG key per batched sample, all derived from a fixed seed of 0.
num_samples = 100
keys = jr.split(jr.PRNGKey(0), num=num_samples)
# NEW CONTROLLER
@jax.jit
@eqx.debug.assert_max_traces(max_traces=1)
def time_new_controller_fun():
    """Solve every sample with ``new_controller`` and return the saved states.

    The shape checks run while tracing; ``assert_max_traces`` makes a second
    trace an error, so the timed calls reuse one compiled function.
    """
    solution = solve(keys, new_controller)
    saved = solution.ys
    assert saved is not None
    expected_shape = (num_samples, len(step_ts))
    assert saved.shape == expected_shape
    return saved


def time_new_controller():
    """One timed call: wait for the result so the measurement covers the work."""
    pending = time_new_controller_fun()
    jax.block_until_ready(pending)
# OLD CONTROLLER
@jax.jit
@eqx.debug.assert_max_traces(max_traces=1)
def time_old_controller_fun():
    """Solve every sample with ``old_controller`` and return the saved states.

    The shape checks run while tracing; ``assert_max_traces`` makes a second
    trace an error, so the timed calls reuse one compiled function.
    """
    solution = solve(keys, old_controller)
    saved = solution.ys
    assert saved is not None
    expected_shape = (num_samples, len(step_ts))
    assert saved.shape == expected_shape
    return saved


def time_old_controller():
    """One timed call: wait for the result so the measurement covers the work."""
    pending = time_old_controller_fun()
    jax.block_until_ready(pending)
# Best (minimum) wall time over 20 repeats, each repeat timing 3 back-to-back
# calls. The first repeat also pays for JIT compilation, which the minimum
# discards. The new controller is timed first, then the old one.
time_new, time_old = (
    min(timeit.repeat(bench, number=3, repeat=20))
    for bench in (time_new_controller, time_old_controller)
)
print(f"New controller: {time_new:.5} s, Old controller: {time_old:.5} s")
# How expensive is revisiting rejected steps?
# Same wrapper as `new_controller`, but with a rejected-step buffer of length 10.
new_revisiting_controller = diffrax.JumpStepWrapper(
    pid_controller,
    step_ts=step_ts,
    rejected_step_buffer_len=10,
)


# Decorated exactly like the new/old benchmarks so the three timings are
# measured under the same conditions. Without `jax.jit`, every call re-runs
# this wrapper in Python and dispatches through `eqx.filter_jit`.
# `assert_max_traces` also turns any accidental retrace into an error.
@jax.jit
@eqx.debug.assert_max_traces(max_traces=1)
def time_revisiting_controller_fun():
    """Solve every sample with ``new_revisiting_controller``; return saved states."""
    sols = solve(keys, new_revisiting_controller)
    assert sols.ys is not None
    assert sols.ys.shape == (num_samples, len(step_ts))
    return sols.ys


def time_revisiting_controller():
    """One timed call: wait for the result so the measurement covers the work."""
    jax.block_until_ready(time_revisiting_controller_fun())


# Same statistic as the other benchmarks: best of 20 repeats of 3 calls each.
time_revisiting = min(timeit.repeat(time_revisiting_controller, number=3, repeat=20))
print(f"Revisiting controller: {time_revisiting:.5} s")
# ======= RESULTS (best of 20 repeats; each time covers 3 calls) =======
# New controller: 0.22829 s, Old controller: 0.31039 s
# Revisiting controller: 0.23212 s