diff --git a/examples/sde_example.ipynb b/examples/sde_example.ipynb new file mode 100644 index 00000000..76078c20 --- /dev/null +++ b/examples/sde_example.ipynb @@ -0,0 +1,889 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-07T13:23:27.783031Z", + "start_time": "2024-06-07T13:23:26.687282Z" + }, + "collapsed": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: JAX_PLATFORM_NAME=cuda\n" + ] + } + ], + "source": [ + "%env JAX_PLATFORM_NAME=cuda\n", + "\n", + "\n", + "import diffrax\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.random as jr\n", + "import matplotlib.pyplot as plt\n", + "from diffrax import (\n", + " SpaceTimeLevyArea,\n", + " SPaRK,\n", + ")\n", + "\n", + "\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "jnp.set_printoptions(precision=4, suppress=True)" + ] + }, + { + "cell_type": "markdown", + "id": "86d4e8b062a81d7e", + "metadata": {}, + "source": [ + "# Simulating SDEs in Diffrax\n", + "\n", + "We will be simulating a Stratonovich SDE of the form:\n", + "$$\n", + " dY_t = f(Y_t, t) dt + g(Y_t, t) \\circ dW_t, \n", + "$$\n", + "where $t \\in [0, T]$, $Y_t \\in \\mathbb{R}^e$, and $W$ is a standard Brownian motion on $\\mathbb{R}^d$. We refer to $f: \\mathbb{R}^e \\times [0, T] \\to \\mathbb{R}^e$ as the drift vector field and $g: \\mathbb{R}^e \\times [0, T] \\to \\mathbb{R}^{e \\times d}$ is the diffusion matrix field. 
The Stratonovich integral is denoted by $\\circ$.\n", + "\n", + "Our SDE will have the following drift and diffusion terms:\n", + "\\begin{align*}\n", + " f(Y_t, t) &= \\alpha - \\beta Y_t, \\\\\n", + " g(Y_t, t) &= \\gamma \\begin{bmatrix} \\Vert Y_t \\Vert_2 & 0 \\\\ 0 & 3 Y_{t, 1} \\\\ 0 & 20t \\end{bmatrix},\n", + "\\end{align*}\n", + "where $\\alpha, \\gamma \\in \\mathbb{R}^3$ and $\\beta \\in \\mathbb{R}_{\\geq 0}$ are some parameters, and $\\gamma$ multiplies the matrix row-wise (row $i$ is scaled by $\\gamma_i$).\n", + "\n", + "Let's write the SDE in the form that Diffrax expects:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ba23e9cc0370fbac", + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-03T10:59:07.293401Z", + "start_time": "2024-06-03T10:59:07.223425Z" + } + }, + "outputs": [], + "source": [ + "# Drift VF (e = 3)\n", + "def f(t, y, args):\n", + " alpha, beta, gamma = args\n", + " beta = jnp.abs(beta)\n", + " assert alpha.shape == (3,)\n", + " return jnp.array(alpha - beta * y, dtype=y.dtype)\n", + "\n", + "\n", + "# Diffusion matrix field (e = 3, d = 2)\n", + "def g(t, y, args):\n", + " alpha, beta, gamma = args\n", + " assert gamma.shape == y.shape == (3,)\n", + " gamma = jnp.reshape(gamma, (3, 1))\n", + " out = gamma * jnp.array(\n", + " [[jnp.sqrt(jnp.sum(y**2)), 0.0], [0.0, 3 * y[0]], [0.0, 20 * t]], dtype=y.dtype\n", + " )\n", + " return out\n", + "\n", + "\n", + "# Initial condition\n", + "y0 = jnp.array([1.0, 1.0, 1.0])\n", + "\n", + "# Args\n", + "alpha = 0.5 * jnp.ones((3,))\n", + "beta = 1.0\n", + "gamma = jnp.ones((3,))\n", + "args = (alpha, beta, gamma)\n", + "\n", + "# Time domain\n", + "t0 = 0.0\n", + "t1 = 2.0" + ] + }, + { + "cell_type": "markdown", + "id": "ef2ff90865907b7d", + "metadata": {}, + "source": [ + "## Brownian motion and its Levy area\n", + "\n", + "Different solvers require different information about the Brownian motion. For example, the `SPaRK` solver requires access to the space-time Levy area of the Brownian motion. 
The required Levy area for each solver is documented in the table at the end of this notebook, or can be checked via `solver.minimal_levy_area`.\n", + " \n", + "We will use the `VirtualBrownianTree` class to generate the Brownian motion and its Levy area." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4110735158215acc", + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-03T10:59:08.656002Z", + "start_time": "2024-06-03T10:59:08.593425Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Minimal levy area for SPaRK: AbstractSpaceTimeLevyArea.\n" + ] + } + ], + "source": [ + "# check minimal levy area\n", + "solver = SPaRK()\n", + "print(f\"Minimal levy area for SPaRK: {solver.minimal_levy_area.__name__}.\")\n", + "\n", + "# Brownian motion\n", + "key = jr.key(0)\n", + "bm_tol = 2**-13\n", + "bm_shape = (2,)\n", + "bm = diffrax.VirtualBrownianTree(\n", + " t0, t1, bm_tol, bm_shape, key, levy_area=SpaceTimeLevyArea\n", + ")\n", + "\n", + "# Defining the terms of the SDE\n", + "ode_term = diffrax.ODETerm(f)\n", + "diffusion_term = diffrax.ControlTerm(g, bm) # Note that the BM is baked into the term\n", + "terms = diffrax.MultiTerm(ode_term, diffusion_term)" + ] + }, + { + "cell_type": "markdown", + "id": "e71db03c5257bd46", + "metadata": {}, + "source": [ + "### Using `diffrax.diffeqsolve` to solve the SDE\n", + "\n", + "We will first use constant steps of size $h = 2^{-9}$ to solve the SDE. It is very important to have $h >$ `bm_tol`, where `bm_tol` is the tolerance of the Brownian motion.\n", + "\n", + " We will use the SPaRK method to solve the SDE. SPaRK is a stochastic Runge-Kutta method that requires access to space-time Levy area." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1841a462895a32f", + "metadata": {}, + "outputs": [], + "source": [ + "dt0 = 2**-9" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a969e1b9bd9f09", + "metadata": {}, + "outputs": [], + "source": [ + "sol = diffrax.diffeqsolve(\n", + " terms, SPaRK(), t0, t1, dt0, y0, args, saveat=diffrax.SaveAt(steps=True)\n", + ")\n", + "\n", + "# Plotting the solution on ax1 and the BM on ax2\n", + "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))\n", + "ax1.plot(sol.ts, sol.ys[:, 0], label=\"Y_1\")\n", + "ax1.plot(sol.ts, sol.ys[:, 1], label=\"Y_2\")\n", + "ax1.plot(sol.ts, sol.ys[:, 2], label=\"Y_3\")\n", + "ax1.set_title(\"SDE solution\")\n", + "ax1.legend()\n", + "\n", + "bm_vals = jax.vmap(lambda t: bm.evaluate(t0, t))(jnp.clip(sol.ts, t0, t1))\n", + "ax2.plot(sol.ts, bm_vals[:, 0], label=\"BM_1\")\n", + "ax2.plot(sol.ts, bm_vals[:, 1], label=\"BM_2\")\n", + "ax2.set_title(\"Brownian motion\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "fd3251c814306cd", + "metadata": {}, + "source": [ + "## Using adaptive time-stepping via the PID-controller\n", + "\n", + "Note that the `SPaRK` solver has an embedded method for error estimation. For solvers like `GeneralShARK`, which do not have an embedded method, we'd instead need to use `HalfSolver(GeneralShARK())` as the solver." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42ca5c5520079b5f", + "metadata": {}, + "outputs": [], + "source": [ + "controller = diffrax.PIDController(\n", + " rtol=0,\n", + " atol=0.005,\n", + " pcoeff=0.2,\n", + " icoeff=0.5,\n", + " dcoeff=0,\n", + " dtmin=2**-12,\n", + " dtmax=0.25,\n", + ")\n", + "\n", + "sol_pid_spark = diffrax.diffeqsolve(\n", + " terms,\n", + " SPaRK(),\n", + " t0,\n", + " t1,\n", + " dt0,\n", + " y0,\n", + " args,\n", + " saveat=diffrax.SaveAt(steps=True),\n", + " stepsize_controller=controller,\n", + " max_steps=2**16,\n", + ")\n", + "accepted_steps = sol_pid_spark.stats[\"num_accepted_steps\"]\n", + "rejected_steps = sol_pid_spark.stats[\"num_rejected_steps\"]\n", + "print(\n", + " f\"Accepted steps: {accepted_steps}, Rejected steps: {rejected_steps},\"\n", + " f\" total steps: {accepted_steps + rejected_steps}\"\n", + ")\n", + "\n", + "# Plot the solution on ax1 and the density of ts on ax2\n", + "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))\n", + "ax1.plot(sol_pid_spark.ts, sol_pid_spark.ys[:, 0], label=\"Y_1\")\n", + "ax1.plot(sol_pid_spark.ts, sol_pid_spark.ys[:, 1], label=\"Y_2\")\n", + "ax1.plot(sol_pid_spark.ts, sol_pid_spark.ys[:, 2], label=\"Y_3\")\n", + "ax1.set_title(\"SDE solution\")\n", + "ax1.legend()\n", + "\n", + "# Plot the density of ts\n", + "# sol_pid.ts is padded with inf values at the end, so we remove them\n", + "padding_idx = jnp.argmax(jnp.isinf(sol_pid_spark.ts))\n", + "ts = sol_pid_spark.ts[:padding_idx]\n", + "ax2.hist(ts, bins=100)\n", + "ax2.set_title(\"Density of ts\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6a746e8f4ef4dd66", + "metadata": {}, + "source": "### Using `HalfSolver(GeneralShARK())` for adaptive time-stepping" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3fb04ee4d14e25e3", + "metadata": {}, + "outputs": [], + "source": [ + "half_g_shark = diffrax.HalfSolver(diffrax.GeneralShARK())\n", + "\n", + 
"sol_pid_half_g_shark = diffrax.diffeqsolve(\n", + " terms,\n", + " half_g_shark,\n", + " t0,\n", + " t1,\n", + " dt0,\n", + " y0,\n", + " args,\n", + " saveat=diffrax.SaveAt(steps=True),\n", + " stepsize_controller=controller,\n", + " max_steps=2**16,\n", + ")\n", + "accepted_steps = sol_pid_half_g_shark.stats[\"num_accepted_steps\"]\n", + "rejected_steps = sol_pid_half_g_shark.stats[\"num_rejected_steps\"]\n", + "print(\n", + " f\"Accepted steps: {accepted_steps}, Rejected steps: {rejected_steps},\"\n", + " f\" total steps: {accepted_steps + rejected_steps}\"\n", + ")\n", + "\n", + "# Plot the solution on ax1 and the density of ts on ax2\n", + "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))\n", + "ax1.plot(sol_pid_half_g_shark.ts, sol_pid_half_g_shark.ys[:, 0], label=\"Y_1\")\n", + "ax1.plot(sol_pid_half_g_shark.ts, sol_pid_half_g_shark.ys[:, 1], label=\"Y_2\")\n", + "ax1.plot(sol_pid_half_g_shark.ts, sol_pid_half_g_shark.ys[:, 2], label=\"Y_3\")\n", + "ax1.set_title(\"SDE solution\")\n", + "ax1.legend()\n", + "\n", + "# Plot the density of ts\n", + "# sol_pid.ts is padded with inf values at the end, so we remove them\n", + "padding_idx = jnp.argmax(jnp.isinf(sol_pid_half_g_shark.ts))\n", + "ts = sol_pid_half_g_shark.ts[:padding_idx]\n", + "ax2.hist(ts, bins=100)\n", + "ax2.set_title(\"Density of ts\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "344b5f07d5120128", + "metadata": {}, + "source": [ + "## Solving an SDE for a batch of Brownian motions\n", + "\n", + "When doing Monte Carlo simulations, we often need to solve the same SDE for multiple Brownian motions. We can do this via `jax.vmap`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffe3ced461ebb823", + "metadata": {}, + "outputs": [], + "source": [ + "def get_terms(bm):\n", + " return diffrax.MultiTerm(ode_term, diffrax.ControlTerm(g, bm))\n", + "\n", + "\n", + "# Fix which times we step to (this is equivalent to a constant step size)\n", + "# We do this because the combination of using dt0 and SaveAt(steps=True) pads the\n", + "# output with inf values up to max_steps.\n", + "# Instead we specify exactly which times we want to save at, so Diffrax allocates\n", + "# the correct amount of memory at the outset.\n", + "num_steps = int((t1 - t0) / dt0)\n", + "step_times = jnp.linspace(t0, t1, num_steps + 1, endpoint=True)\n", + "constant_controller = diffrax.StepTo(ts=step_times)\n", + "saveat = diffrax.SaveAt(ts=step_times)\n", + "\n", + "\n", + "# We will vmap over keys\n", + "@jax.jit\n", + "@jax.vmap\n", + "def batch_sde_solve(key):\n", + " bm = diffrax.VirtualBrownianTree(\n", + " t0, t1, bm_tol, bm_shape, key, levy_area=SpaceTimeLevyArea\n", + " )\n", + " terms = get_terms(bm)\n", + " return diffrax.diffeqsolve(\n", + " terms,\n", + " SPaRK(),\n", + " t0,\n", + " t1,\n", + " None,\n", + " y0,\n", + " args,\n", + " saveat=saveat,\n", + " stepsize_controller=constant_controller,\n", + " )\n", + "\n", + "\n", + "# Split the keys and compute the batched solutions\n", + "num_samples = 100\n", + "keys = jr.split(jr.PRNGKey(0), num_samples)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c1206025f30100d", + "metadata": {}, + "outputs": [], + "source": [ + "batch_sols = batch_sde_solve(keys)\n", + "print(\n", + " f\"Shape of batch_sols: \"\n", + " f\"{batch_sols.ys.shape} == {num_samples} x {num_steps + 1} x (dim of Y)\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "71dda42d79d4c553", + "metadata": {}, + "source": [ + "## Optimizing wrt. 
SDE parameters\n", + "A function with a similar behaviour to `batch_sde_solve` is available in `test.helpers`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d278fc2d438ffc82", + "metadata": {}, + "outputs": [], + "source": [ + "from test.helpers import _batch_sde_solve\n", + "\n", + "import optax\n", + "from jax import Array\n", + "\n", + "\n", + "bm_shape = (2,)\n", + "# Note that _batch_sde_solve doesn't output the whole solution object but just a\n", + "# tuple (ys, num_steps_output). The number of steps is there for\n", + "batch_ys, num_steps_output = _batch_sde_solve(\n", + " keys,\n", + " get_terms,\n", + " bm_shape,\n", + " t0,\n", + " t1,\n", + " y0,\n", + " args,\n", + " SPaRK(),\n", + " SpaceTimeLevyArea,\n", + " None,\n", + " constant_controller,\n", + " bm_tol,\n", + " saveat,\n", + " use_progress_meter=False,\n", + ")\n", + "\n", + "print(\n", + " f\"Shape of batch_ys: \"\n", + " f\"{batch_ys.shape} == {num_samples} x {num_steps + 1} x (dim of Y)\"\n", + ")\n", + "ys_t1 = batch_ys[:, -1]\n", + "mean_t1 = jnp.mean(ys_t1, axis=0)\n", + "var_t1 = jnp.mean(ys_t1**2, axis=0) - mean_t1**2\n", + "print(f\"Stats at t=t1: mean={mean_t1}, var={var_t1}\")\n", + "\n", + "\n", + "# We will optimize for achieving a mean of 0\n", + "def loss(args: tuple[Array, Array, Array]):\n", + " batch_ys, num_steps_output = _batch_sde_solve(\n", + " keys,\n", + " get_terms,\n", + " bm_shape,\n", + " t0,\n", + " t1,\n", + " y0,\n", + " args,\n", + " SPaRK(),\n", + " SpaceTimeLevyArea,\n", + " 2**-7,\n", + " None,\n", + " bm_tol,\n", + " diffrax.SaveAt(t1=True),\n", + " use_progress_meter=False,\n", + " )\n", + " assert batch_ys.shape == (num_samples, 1, 3)\n", + " mean = jnp.mean(batch_ys, axis=(0, 1))\n", + " std = jnp.sqrt(jnp.mean(batch_ys**2, axis=(0, 1)) - mean**2)\n", + " target_mean = jnp.array([0.0, 1.0, 0.0])\n", + " target_stds = 2 * jnp.ones((3,))\n", + " loss = jnp.sqrt(\n", + " jnp.sum((mean - target_mean) ** 2) + jnp.sum((std - 
target_stds) ** 2)\n", + " )\n", + " return loss\n", + "\n", + "\n", + "# Define the parameters to optimize\n", + "alpha_opt = 0.5 * jnp.ones((3,))\n", + "beta_opt = jnp.array(1.0)\n", + "gamma_opt = jnp.ones((3,))\n", + "args_opt = (alpha_opt, beta_opt, gamma_opt)\n", + "\n", + "# Define the optimizer\n", + "num_steps = 191\n", + "schedule = optax.cosine_decay_schedule(3e-1, num_steps, 1e-2)\n", + "opt = optax.chain(\n", + " optax.scale_by_adam(b1=0.9, b2=0.99, eps=1e-8),\n", + " optax.scale_by_schedule(schedule),\n", + " optax.scale(-1),\n", + ")\n", + "# opt = optax.adam(2e-1)\n", + "opt_state = opt.init(args_opt)\n", + "\n", + "\n", + "@jax.jit\n", + "def step(i, opt_state, args):\n", + " loss_val, grad = jax.value_and_grad(loss)(args)\n", + " updates, opt_state = opt.update(grad, opt_state)\n", + "\n", + " # One way to apply updates\n", + " # args = optax.apply_updates(args, updates)\n", + "\n", + " # Another way to apply updates\n", + " args = jax.tree_util.tree_map(lambda x, u: x + u, args, updates)\n", + "\n", + " return opt_state, args, loss_val\n", + "\n", + "\n", + "for i in range(num_steps):\n", + " opt_state, args_opt, loss_val = step(i, opt_state, args_opt)\n", + " alpha_opt, beta_opt, gamma_opt = args_opt\n", + " if i % 10 == 0:\n", + " print(f\"Step {i}, loss: {loss_val}\")\n", + "\n", + "print(\n", + " f\"Optimal parameters:\\n\"\n", + " f\"alpha={alpha_opt},\"\n", + " f\" beta={beta_opt}, gamma={gamma_opt}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "834651877787c7e6", + "metadata": {}, + "outputs": [], + "source": [ + "batch_ys_opt, num_steps_output = _batch_sde_solve(\n", + " keys,\n", + " get_terms,\n", + " bm_shape,\n", + " t0,\n", + " t1,\n", + " y0,\n", + " args_opt,\n", + " SPaRK(),\n", + " SpaceTimeLevyArea,\n", + " None,\n", + " constant_controller,\n", + " bm_tol,\n", + " saveat,\n", + " use_progress_meter=False,\n", + ")\n", + "ys_t1 = batch_ys_opt[:, -1]\n", + "mean_t1 = jnp.mean(ys_t1, 
axis=0)\n", + "var_t1 = jnp.mean(ys_t1**2, axis=0) - mean_t1**2\n", + "print(f\"Stats at t=t1: mean={mean_t1}, var={var_t1}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a952251cf67716e8", + "metadata": {}, + "outputs": [], + "source": [ + "# simulate the optimal SDE\n", + "\n", + "bm = diffrax.VirtualBrownianTree(\n", + " t0, t1, bm_tol, bm_shape, jr.key(0), levy_area=SpaceTimeLevyArea\n", + ")\n", + "terms = get_terms(bm)\n", + "sol_optimal = diffrax.diffeqsolve(\n", + " terms,\n", + " SPaRK(),\n", + " t0,\n", + " t1,\n", + " None,\n", + " y0,\n", + " args_opt,\n", + " saveat=saveat,\n", + " stepsize_controller=constant_controller,\n", + ")\n", + "\n", + "fig, ax = plt.subplots(1, 1, figsize=(6, 4))\n", + "ax.plot(sol_optimal.ts, sol_optimal.ys[:, 0], label=\"Y_1\")\n", + "ax.plot(sol_optimal.ts, sol_optimal.ys[:, 1], label=\"Y_2\")\n", + "ax.plot(sol_optimal.ts, sol_optimal.ys[:, 2], label=\"Y_3\")\n", + "ax.set_title(\"SDE solution\")\n", + "ax.legend()\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "fa031b6ebb679028", + "metadata": {}, + "source": [ + "# Table of available SRK methods in Diffrax\n", + "\n", + "## Itô vs Stratonovich SDEs\n", + "Some of the solvers converge to the Itô solution of the SDE and others to the Stratonovich solution. The Itô and Stratonovich solutions coincide iff the SDE has additive noise (as defined below). For other SDEs it is possible to convert them between the Itô and Stratonovich versions using the Itô-Stratonovich correction term.\n", + "\n", + "\n", + "## Noise type\n", + "Depending on the type of noise (i.e. diffusion term) present in the SDE, different SRK methods have different strong orders of convergence. 
These types of noise are the same for Itô and Stratonovich SDEs.\n", + "\n", + "\n", + "### General noise\n", + "Any SDE of the form (as above):\n", + "$$\n", + " dY_t = f(Y_t, t) dt + g(Y_t, t) dW_t, \n", + "$$\n", + "where $t \\in [0, T]$, $Y_t \\in \\mathbb{R}^e$, and $W$ is a standard Brownian motion on $\\mathbb{R}^d$. We refer to $f: \\mathbb{R}^e \\times [0, T] \\to \\mathbb{R}^e$ as the drift vector field and to $g: \\mathbb{R}^e \\times [0, T] \\to \\mathbb{R}^{e \\times d}$ as the diffusion matrix field with columns $g_i$ for $i = 1, \\ldots, d$.\n", + "\n", + "\n", + "### Commutative noise\n", + "We say that the diffusion is commutative when the columns of $g$ commute in the Lie bracket, that is\n", + "$$\n", + " \\frac{d}{dy} g_i(y, t) \\; g_j(y, t) = \\frac{d}{dy} g_j(y, t) \\; g_i(y, t) \\quad \\forall \\, y, t.\n", + "$$\n", + "For example, this holds when $g$ is diagonal or when the dimension of BM is $d=1$.\n", + "\n", + "\n", + "### Additive noise\n", + "We say that the diffusion is additive when $g$ does not depend on $Y_t$ and the SDE can be written as\n", + "$$\n", + " dY_t = f(Y_t, t) dt + g(t) dW_t.\n", + "$$\n", + "Additive noise is a special case of commutative noise. For additive noise SDEs, the Itô and Stratonovich solutions coincide. Some solvers (ShARK, SRA1, SEA) are specifically designed for additive noise SDEs, so for those we do not specify whether they are Itô or Stratonovich." 
+ ] + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "```\n", + "+----------------+-------+------------+----------------------------------+-------------------+----------------+------------------------------------------+\n", + "| | SDE | Lévy | Strong/weak order per noise type | VF evaluations | Embedded error | Recommended for |\n", + "| | type | area +---------+-------------+----------+-------+-----------+ estimation | (and other notes) |\n", + "| | | | General | Commutative | Additive | Drift | Diffusion | | |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| Euler | Itô | BM only | 0.5/1.0 | 0.5/1.0 | 1.0/1.0 | 1 | 1 | No | Itô SDEs when a cheap solver is needed |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| Heun | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 2 | 2 | Yes | Stratonovich SDEs without space-time LA. |\n", + "| | | | | | | | | | Has weak order 2 for constant noise. |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| ItoMilstein | Itô | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 1 | 1 + deriv | No | Better than Euler for Itô SDEs, but |\n", + "| | | | | | | | ative | | computes the derivative of diffusion VF |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| Stratonovich | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 1 | 1 + deriv | No | For commutative Stratonovich SDEs when |\n", + "| Milstein | | | | | | | ative | | space-time Lévy area is not available. |\n", + "| | | | | | | | | | Computes derivative of diffusion VF. 
|\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| ReversibleHeun | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 1 | 1 | Yes | When a reversible solver is needed, e.g. |\n", + "| | | | | | | | | | for Neural SDEs. This method is FSAL. |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| Midpoint | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 2 | 2 | Yes | Usually Heun should be preferred. |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| ShARK | Addit | space-time | / | / | 1.5/2.0 | 2 | 2 | Yes | Additive noise SDEs |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| SRA1 | Addit | space-time | / | / | 1.5/2.0 | 2 | 2 | Yes | Only slightly worse than ShARK |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| SEA | Addit | space-time | / | / | 1.0/1.0 | 1 | 1 | No | Cheap solver for additive noise SDEs |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| SPaRK | Strat | space-time | 0.5/1.0 | 1.0/1.0 | 1.5/2.0 | 3 | 3 | Yes | General SDEs when embedded error |\n", + "| | | | | | | | | | estimation is needed |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| GeneralShARK | Strat | space-time | 0.5/1.0 | 1.0/1.0 | 1.5/2.0 | 2 | 3 
| No | General SDEs when embedded error |\n", + "| | | | | | | | | | estimation is not needed |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| SlowRK | Strat | space-time | 0.5/1.0 | 1.5/2.0 | 1.5/2.0 | 2 | 5 | No | Commutative noise SDEs. |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "```" + ], + "id": "4a49d8a22ccfcfe9" + }, + { + "cell_type": "markdown", + "id": "58c5d15087e0b321", + "metadata": {}, + "source": [ + "### Appendix: Using the `WeaklyDiagonalControlTerm`\n", + "\n", + "Suppose we instead have $e = d = 3$ and the diffusion matrix field always returns a diagonal matrix. That is, $\\forall t, y$, the matrix $g(y,t)$ is diagonal. Then we can use the `WeaklyDiagonalControlTerm`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d62ffefdc35dac76", + "metadata": {}, + "outputs": [], + "source": [ + "bm2 = diffrax.VirtualBrownianTree(\n", + " t0, t1, bm_tol, (3,), key, levy_area=SpaceTimeLevyArea\n", + ")\n", + "\n", + "\n", + "# Using the regular ControlTerm\n", + "def g_CT(t, y, args):\n", + " return jnp.array(\n", + " [[jnp.abs(jnp.sum(y)), 0.0, 0.0], [0.0, y[0], 0.0], [0.0, 0.0, 10 * t]],\n", + " dtype=y.dtype,\n", + " )\n", + "\n", + "\n", + "diffusion_term_CT = diffrax.ControlTerm(g_CT, bm2)\n", + "\n", + "\n", + "# The same SDE using the WeaklyDiagonalControlTerm\n", + "def g_WD(t, y, args):\n", + " return jnp.array([jnp.abs(jnp.sum(y)), y[0], 10 * t], dtype=y.dtype)\n", + "\n", + "\n", + "diffusion_term_WD = diffrax.WeaklyDiagonalControlTerm(g_WD, bm2)\n", + "\n", + "terms_CT = diffrax.MultiTerm(ode_term, diffusion_term_CT)\n", + "terms_WD = diffrax.MultiTerm(ode_term, diffusion_term_WD)\n", + "\n", + "sol_CT = diffrax.diffeqsolve(\n", + " terms_CT,\n", + " SPaRK(),\n", + " t0,\n", + 
" t1,\n", + " dt0,\n", + " y0,\n", + " args,\n", + " saveat=diffrax.SaveAt(steps=True),\n", + ")\n", + "sol_WD = diffrax.diffeqsolve(\n", + " terms_WD,\n", + " SPaRK(),\n", + " t0,\n", + " t1,\n", + " dt0,\n", + " y0,\n", + " args,\n", + " saveat=diffrax.SaveAt(steps=True),\n", + ")\n", + "assert jnp.allclose(sol_CT.ys, sol_WD.ys)\n", + "\n", + "# Plot the solution\n", + "fig, ax = plt.subplots(1, 1, figsize=(6, 4))\n", + "ax.plot(sol.ts, sol_WD.ys[:, 0], label=\"Y_1\")\n", + "ax.plot(sol.ts, sol_WD.ys[:, 1], label=\"Y_2\")\n", + "ax.plot(sol.ts, sol_WD.ys[:, 2], label=\"Y_3\")\n", + "ax.set_title(\"SDE solution\")\n", + "ax.legend()\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "62b40efc1e1b19fb", + "metadata": {}, + "source": "### Unstable SDE" + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "9df7b5e23f3162ce", + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-03T11:56:21.530191Z", + "start_time": "2024-06-03T11:56:10.490604Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Num steps: constant: 2000, adaptive: 1185\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def f_unstable(t, y, args):\n", + " return -(y**3)\n", + "\n", + "\n", + "def g_unstable(t, y, args):\n", + " return jnp.array(10.0, dtype=y.dtype)\n", + "\n", + "\n", + "y0_unstable = jnp.array(0.0)\n", + "\n", + "t1_unstable = 100.0\n", + "\n", + "key_unstable = jr.key(0)\n", + "bm_1d = diffrax.VirtualBrownianTree(\n", + " t0, t1_unstable, 2**-10, (), key_unstable, levy_area=SpaceTimeLevyArea\n", + ")\n", + "\n", + "terms = diffrax.MultiTerm(\n", + " diffrax.ODETerm(f_unstable),\n", + " diffrax.ControlTerm(g_unstable, bm_1d),\n", + ")\n", + "\n", + "dt_unstable = 0.05\n", + "controller = diffrax.PIDController(\n", + " rtol=0,\n", + " atol=10.0,\n", + " pcoeff=0.2,\n", + " icoeff=0.5,\n", + " dcoeff=0,\n", + " dtmin=2**-10,\n", + " dtmax=0.5,\n", + ")\n", + "\n", + "sol_const = diffrax.diffeqsolve(\n", + " terms,\n", + " SPaRK(),\n", + " t0,\n", + " t1_unstable,\n", + " dt_unstable,\n", + " y0_unstable,\n", + " (),\n", + " saveat=diffrax.SaveAt(steps=True),\n", + " max_steps=2**18,\n", + ")\n", + "\n", + "sol_adap = diffrax.diffeqsolve(\n", + " terms,\n", + " SPaRK(),\n", + " t0,\n", + " t1_unstable,\n", + " dt_unstable,\n", + " y0_unstable,\n", + " (),\n", + " saveat=diffrax.SaveAt(steps=True),\n", + " stepsize_controller=controller,\n", + " max_steps=2**18,\n", + ")\n", + "\n", + "print(\n", + " f\"Num steps: constant: {sol_const.stats['num_steps']},\"\n", + " f\" adaptive: {sol_adap.stats['num_steps']}\"\n", + ")\n", + "\n", + "# Plot each solution in a separate figure\n", + "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))\n", + "ax1.plot(sol_const.ts, sol_const.ys, label=\"Constant\")\n", + "ax1.set_title(\"Constant step size\")\n", + "ax1.legend()\n", + "\n", + "ax2.plot(sol_adap.ts, sol_adap.ys, label=\"Adaptive\")\n", + "ax2.set_title(\"Adaptive step size\")\n", + "ax2.legend()\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, 
+ "id": "2c713ae477ea7b36", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}