diff --git a/examples/sde_example.ipynb b/examples/sde_example.ipynb new file mode 100644 index 00000000..76078c20 --- /dev/null +++ b/examples/sde_example.ipynb @@ -0,0 +1,889 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-07T13:23:27.783031Z", + "start_time": "2024-06-07T13:23:26.687282Z" + }, + "collapsed": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: JAX_PLATFORM_NAME=cuda\n" + ] + } + ], + "source": [ + "%env JAX_PLATFORM_NAME=cuda\n", + "\n", + "\n", + "import diffrax\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.random as jr\n", + "import matplotlib.pyplot as plt\n", + "from diffrax import (\n", + " SpaceTimeLevyArea,\n", + " SPaRK,\n", + ")\n", + "\n", + "\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "jnp.set_printoptions(precision=4, suppress=True)" + ] + }, + { + "cell_type": "markdown", + "id": "86d4e8b062a81d7e", + "metadata": {}, + "source": [ + "# Simulating SDEs in Diffrax\n", + "\n", + "We will be simulating a Stratonovich SDE of the form:\n", + "$$\n", + " dY_t = f(Y_t, t) dt + g(Y_t, t) \\circ dW_t, \n", + "$$\n", + "where $t \\in [0, T]$, $Y_t \\in \\mathbb{R}^e$, and $W$ is a standard Brownian motion on $\\mathbb{R}^d$. We refer to $f: \\mathbb{R}^e \\times [0, T] \\to \\mathbb{R}^e$ as the drift vector field and $g: \\mathbb{R}^e \\times [0, T] \\to \\mathbb{R}^{e \\times d}$ is the diffusion matrix field. The Stratonovich integral is denoted by $\\circ$.\n", + "\n", + "Our SDE will have the following drift and diffusion terms:\n", + "\\begin{align*}\n", + " f(Y_t, t) &= \\alpha - \\beta Y_t, \\\\\n", + " g(Y_t, t) &= \\gamma \\begin{bmatrix} \\Vert Y_t \\Vert_2 & 0 \\\\ 0 & Y_{t, 1} \\\\ 0 & 10t \\end{bmatrix},\n", + "\\end{align*}\n", + "where $\\alpha, \\gamma \\in \\mathbb{R}^3$ and $\\beta \\in \\mathbb{R}_{\\geq 0}$ are some parameters.\n", + "\n", + "Let's write the SDE in the form that Diffrax expects:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ba23e9cc0370fbac", + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-03T10:59:07.293401Z", + "start_time": "2024-06-03T10:59:07.223425Z" + } + }, + "outputs": [], + "source": [ + "# Drift VF (e = 3)\n", + "def f(t, y, args):\n", + " alpha, beta, gamma = args\n", + " beta = jnp.abs(beta)\n", + " assert alpha.shape == (3,)\n", + " return jnp.array(alpha - beta * y, dtype=y.dtype)\n", + "\n", + "\n", + "# Diffusion matrix field (e = 3, d = 2)\n", + "def g(t, y, args):\n", + " alpha, beta, gamma = args\n", + " assert gamma.shape == y.shape == (3,)\n", + " gamma = jnp.reshape(gamma, (3, 1))\n", + " out = gamma * jnp.array(\n", + " [[jnp.sqrt(jnp.sum(y**2)), 0.0], [0.0, 3 * y[0]], [0.0, 20 * t]], dtype=y.dtype\n", + " )\n", + " return out\n", + "\n", + "\n", + "# Initial condition\n", + "y0 = jnp.array([1.0, 1.0, 1.0])\n", + "\n", + "# Args\n", + "alpha = 0.5 * jnp.ones((3,))\n", + "beta = 1.0\n", + "gamma = jnp.ones((3,))\n", + "args = (alpha, beta, gamma)\n", + "\n", + "# Time domain\n", + "t0 = 0.0\n", + "t1 = 2.0" + ] + }, + { + "cell_type": "markdown", + "id": "ef2ff90865907b7d", + "metadata": {}, + "source": [ + "## Brownian motion and its Levy area\n", + "\n", + "Different solvers require different information about the Brownian motion. For example, the `SPaRK` solver requires access to the space-time Levy area of the Brownian motion. The required Levy area for each solver is documented in the table at the end of this notebook, or can be checked via `solver.minimal_levy_area`.\n", + " \n", + "We will use the `VirtualBrownianTree` class to generate the Brownian motion and its Levy area." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4110735158215acc", + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-03T10:59:08.656002Z", + "start_time": "2024-06-03T10:59:08.593425Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Minimal levy area for SPaRK: AbstractSpaceTimeLevyArea.\n" + ] + } + ], + "source": [ + "# check minimal levy area\n", + "solver = SPaRK()\n", + "print(f\"Minimal levy area for SPaRK: {solver.minimal_levy_area.__name__}.\")\n", + "\n", + "# Brownian motion\n", + "key = jr.key(0)\n", + "bm_tol = 2**-13\n", + "bm_shape = (2,)\n", + "bm = diffrax.VirtualBrownianTree(\n", + " t0, t1, bm_tol, bm_shape, key, levy_area=SpaceTimeLevyArea\n", + ")\n", + "\n", + "# Defining the terms of the SDE\n", + "ode_term = diffrax.ODETerm(f)\n", + "diffusion_term = diffrax.ControlTerm(g, bm) # Note that the BM is baked into the term\n", + "terms = diffrax.MultiTerm(ode_term, diffusion_term)" + ] + }, + { + "cell_type": "markdown", + "id": "e71db03c5257bd46", + "metadata": {}, + "source": [ + "### Using `diffrax.diffeqsolve` to solve the SDE\n", + "\n", + "We will first use constant steps of size $h = 2^{-9}$ to solve the SDE. It is very important to have $h > \\mathtt{bm_tol}$, where $\\mathtt{bm_tol}$ is the tolerance of the Brownian motion.\n", + "\n", + " We will use the SPaRK method to solve the SDE. ShARK is a stochastic Runge-Kutta method that requires access to space-time Levy area." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1841a462895a32f", + "metadata": {}, + "outputs": [], + "source": [ + "dt0 = 2**-9" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a969e1b9bd9f09", + "metadata": {}, + "outputs": [], + "source": [ + "sol = diffrax.diffeqsolve(\n", + " terms, SPaRK(), t0, t1, dt0, y0, args, saveat=diffrax.SaveAt(steps=True)\n", + ")\n", + "\n", + "# Plotting the solution on ax1 and the BM on ax2\n", + "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))\n", + "ax1.plot(sol.ts, sol.ys[:, 0], label=\"Y_1\")\n", + "ax1.plot(sol.ts, sol.ys[:, 1], label=\"Y_2\")\n", + "ax1.plot(sol.ts, sol.ys[:, 2], label=\"Y_3\")\n", + "ax1.set_title(\"SDE solution\")\n", + "ax1.legend()\n", + "\n", + "bm_vals = jax.vmap(lambda t: bm.evaluate(t0, t))(jnp.clip(sol.ts, t0, t1))\n", + "ax2.plot(sol.ts, bm_vals[:, 0], label=\"BM_1\")\n", + "ax2.plot(sol.ts, bm_vals[:, 1], label=\"BM_2\")\n", + "ax2.set_title(\"Brownian motion\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "fd3251c814306cd", + "metadata": {}, + "source": [ + "## Using adaptive time-stepping via the PID-controller\n", + "\n", + "Note that the `SPaRK` solver has an embedded method for error estimation. For solvers like `GeneralShARK`, which do not have an embedded method, we'd instead need to use `HalfSolver(GeneralShARK())` as the solver." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42ca5c5520079b5f", + "metadata": {}, + "outputs": [], + "source": [ + "controller = diffrax.PIDController(\n", + " rtol=0,\n", + " atol=0.005,\n", + " pcoeff=0.2,\n", + " icoeff=0.5,\n", + " dcoeff=0,\n", + " dtmin=2**-12,\n", + " dtmax=0.25,\n", + ")\n", + "\n", + "sol_pid_spark = diffrax.diffeqsolve(\n", + " terms,\n", + " SPaRK(),\n", + " t0,\n", + " t1,\n", + " dt0,\n", + " y0,\n", + " args,\n", + " saveat=diffrax.SaveAt(steps=True),\n", + " stepsize_controller=controller,\n", + " max_steps=2**16,\n", + ")\n", + "accepted_steps = sol_pid_spark.stats[\"num_accepted_steps\"]\n", + "rejected_steps = sol_pid_spark.stats[\"num_rejected_steps\"]\n", + "print(\n", + " f\"Accepted steps: {accepted_steps}, Rejected steps: {rejected_steps},\"\n", + " f\" total steps: {accepted_steps + rejected_steps}\"\n", + ")\n", + "\n", + "# Plot the solution on ax1 and the density of ts on ax2\n", + "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))\n", + "ax1.plot(sol_pid_spark.ts, sol_pid_spark.ys[:, 0], label=\"Y_1\")\n", + "ax1.plot(sol_pid_spark.ts, sol_pid_spark.ys[:, 1], label=\"Y_2\")\n", + "ax1.plot(sol_pid_spark.ts, sol_pid_spark.ys[:, 2], label=\"Y_3\")\n", + "ax1.set_title(\"SDE solution\")\n", + "ax1.legend()\n", + "\n", + "# Plot the density of ts\n", + "# sol_pid.ts is padded with inf values at the end, so we remove them\n", + "padding_idx = jnp.argmax(jnp.isinf(sol_pid_spark.ts))\n", + "ts = sol_pid_spark.ts[:padding_idx]\n", + "ax2.hist(ts, bins=100)\n", + "ax2.set_title(\"Density of ts\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6a746e8f4ef4dd66", + "metadata": {}, + "source": "### Using `HalfSolver(GeneralShARK())` for adaptive time-stepping" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3fb04ee4d14e25e3", + "metadata": {}, + "outputs": [], + "source": [ + "half_g_shark = diffrax.HalfSolver(diffrax.GeneralShARK())\n", + "\n", + "sol_pid_half_g_shark = diffrax.diffeqsolve(\n", + " terms,\n", + " half_g_shark,\n", + " t0,\n", + " t1,\n", + " dt0,\n", + " y0,\n", + " args,\n", + " saveat=diffrax.SaveAt(steps=True),\n", + " stepsize_controller=controller,\n", + " max_steps=2**16,\n", + ")\n", + "accepted_steps = sol_pid_half_g_shark.stats[\"num_accepted_steps\"]\n", + "rejected_steps = sol_pid_half_g_shark.stats[\"num_rejected_steps\"]\n", + "print(\n", + " f\"Accepted steps: {accepted_steps}, Rejected steps: {rejected_steps},\"\n", + " f\" total steps: {accepted_steps + rejected_steps}\"\n", + ")\n", + "\n", + "# Plot the solution on ax1 and the density of ts on ax2\n", + "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))\n", + "ax1.plot(sol_pid_half_g_shark.ts, sol_pid_half_g_shark.ys[:, 0], label=\"Y_1\")\n", + "ax1.plot(sol_pid_half_g_shark.ts, sol_pid_half_g_shark.ys[:, 1], label=\"Y_2\")\n", + "ax1.plot(sol_pid_half_g_shark.ts, sol_pid_half_g_shark.ys[:, 2], label=\"Y_3\")\n", + "ax1.set_title(\"SDE solution\")\n", + "ax1.legend()\n", + "\n", + "# Plot the density of ts\n", + "# sol_pid.ts is padded with inf values at the end, so we remove them\n", + "padding_idx = jnp.argmax(jnp.isinf(sol_pid_half_g_shark.ts))\n", + "ts = sol_pid_half_g_shark.ts[:padding_idx]\n", + "ax2.hist(ts, bins=100)\n", + "ax2.set_title(\"Density of ts\")\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "344b5f07d5120128", + "metadata": {}, + "source": [ + "## Solving an SDE for a batch of Brownian motions\n", + "\n", + "When doing Monte Carlo simulations, we often need to solve the same SDE for multiple Brownian motions. We can do this via `jax.vmap`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffe3ced461ebb823", + "metadata": {}, + "outputs": [], + "source": [ + "def get_terms(bm):\n", + " return diffrax.MultiTerm(ode_term, diffrax.ControlTerm(g, bm))\n", + "\n", + "\n", + "# Fix which times we step to (this is equivalent to a constant step size)\n", + "# We do this because the combination of using dt0 and SaveAt(steps=True) pads the\n", + "# output with inf values up to max_steps.\n", + "# Instead we specify exactly which times we want to save at, so Diffrax allocates\n", + "# the correct amount of memory at the outset.\n", + "num_steps = int((t1 - t0) / dt0)\n", + "step_times = jnp.linspace(t0, t1, num_steps + 1, endpoint=True)\n", + "constant_controller = diffrax.StepTo(ts=step_times)\n", + "saveat = diffrax.SaveAt(ts=step_times)\n", + "\n", + "\n", + "# We will vmap over keys\n", + "@jax.jit\n", + "@jax.vmap\n", + "def batch_sde_solve(key):\n", + " bm = diffrax.VirtualBrownianTree(\n", + " t0, t1, bm_tol, bm_shape, key, levy_area=SpaceTimeLevyArea\n", + " )\n", + " terms = get_terms(bm)\n", + " return diffrax.diffeqsolve(\n", + " terms,\n", + " SPaRK(),\n", + " t0,\n", + " t1,\n", + " None,\n", + " y0,\n", + " args,\n", + " saveat=saveat,\n", + " stepsize_controller=constant_controller,\n", + " )\n", + "\n", + "\n", + "# Split the keys and compute the batched solutions\n", + "num_samples = 100\n", + "keys = jr.split(jr.PRNGKey(0), num_samples)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c1206025f30100d", + "metadata": {}, + "outputs": [], + "source": [ + "batch_sols = batch_sde_solve(keys)\n", + "print(\n", + " f\"Shape of batch_sols: \"\n", + " f\"{batch_sols.ys.shape} == {num_samples} x {num_steps + 1} x (dim of Y)\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "71dda42d79d4c553", + "metadata": {}, + "source": [ + "## Optimizing wrt. SDE parameters\n", + "A function with a similar behaviour to `batch_sde_solve` is available in `test.helpers`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d278fc2d438ffc82", + "metadata": {}, + "outputs": [], + "source": [ + "from test.helpers import _batch_sde_solve\n", + "\n", + "import optax\n", + "from jax import Array\n", + "\n", + "\n", + "bm_shape = (2,)\n", + "# Note that _batch_sde_solve doesn't output the whole solution object but just a\n", + "# tuple (ys, num_steps_output). The number of steps is there for\n", + "batch_ys, num_steps_output = _batch_sde_solve(\n", + " keys,\n", + " get_terms,\n", + " bm_shape,\n", + " t0,\n", + " t1,\n", + " y0,\n", + " args,\n", + " SPaRK(),\n", + " SpaceTimeLevyArea,\n", + " None,\n", + " constant_controller,\n", + " bm_tol,\n", + " saveat,\n", + " use_progress_meter=False,\n", + ")\n", + "\n", + "print(\n", + " f\"Shape of batch_ys: \"\n", + " f\"{batch_ys.shape} == {num_samples} x {num_steps + 1} x (dim of Y)\"\n", + ")\n", + "ys_t1 = batch_ys[:, -1]\n", + "mean_t1 = jnp.mean(ys_t1, axis=0)\n", + "var_t1 = jnp.mean(ys_t1**2, axis=0) - mean_t1**2\n", + "print(f\"Stats at t=t1: mean={mean_t1}, var={var_t1}\")\n", + "\n", + "\n", + "# We will optimize for achieving a mean of 0\n", + "def loss(args: tuple[Array, Array, Array]):\n", + " batch_ys, num_steps_output = _batch_sde_solve(\n", + " keys,\n", + " get_terms,\n", + " bm_shape,\n", + " t0,\n", + " t1,\n", + " y0,\n", + " args,\n", + " SPaRK(),\n", + " SpaceTimeLevyArea,\n", + " 2**-7,\n", + " None,\n", + " bm_tol,\n", + " diffrax.SaveAt(t1=True),\n", + " use_progress_meter=False,\n", + " )\n", + " assert batch_ys.shape == (num_samples, 1, 3)\n", + " mean = jnp.mean(batch_ys, axis=(0, 1))\n", + " std = jnp.sqrt(jnp.mean(batch_ys**2, axis=(0, 1)) - mean**2)\n", + " target_mean = jnp.array([0.0, 1.0, 0.0])\n", + " target_stds = 2 * jnp.ones((3,))\n", + " loss = jnp.sqrt(\n", + " jnp.sum((mean - target_mean) ** 2) + jnp.sum((std - target_stds) ** 2)\n", + " )\n", + " return loss\n", + "\n", + "\n", + "# Define the parameters to optimize\n", + "alpha_opt = 0.5 * jnp.ones((3,))\n", + "beta_opt = jnp.array(1.0)\n", + "gamma_opt = jnp.ones((3,))\n", + "args_opt = (alpha_opt, beta_opt, gamma_opt)\n", + "\n", + "# Define the optimizer\n", + "num_steps = 191\n", + "schedule = optax.cosine_decay_schedule(3e-1, num_steps, 1e-2)\n", + "opt = optax.chain(\n", + " optax.scale_by_adam(b1=0.9, b2=0.99, eps=1e-8),\n", + " optax.scale_by_schedule(schedule),\n", + " optax.scale(-1),\n", + ")\n", + "# opt = optax.adam(2e-1)\n", + "opt_state = opt.init(args_opt)\n", + "\n", + "\n", + "@jax.jit\n", + "def step(i, opt_state, args):\n", + " loss_val, grad = jax.value_and_grad(loss)(args)\n", + " updates, opt_state = opt.update(grad, opt_state)\n", + "\n", + " # One way to apply updates\n", + " # args = optax.apply_updates(args, updates)\n", + "\n", + " # Another way to apply updates\n", + " args = jax.tree_util.tree_map(lambda x, u: x + u, args, updates)\n", + "\n", + " return opt_state, args, loss_val\n", + "\n", + "\n", + "for i in range(num_steps):\n", + " opt_state, args_opt, loss_val = step(i, opt_state, args_opt)\n", + " alpha_opt, beta_opt, gamma_opt = args_opt\n", + " if i % 10 == 0:\n", + " print(f\"Step {i}, loss: {loss_val}\")\n", + "\n", + "print(\n", + " f\"Optimal parameters:\\n\"\n", + " f\"alpha={alpha_opt},\"\n", + " f\" beta={beta_opt}, gamma={gamma_opt}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "834651877787c7e6", + "metadata": {}, + "outputs": [], + "source": [ + "batch_ys_opt, num_steps_output = _batch_sde_solve(\n", + " keys,\n", + " get_terms,\n", + " bm_shape,\n", + " t0,\n", + " t1,\n", + " y0,\n", + " args_opt,\n", + " SPaRK(),\n", + " SpaceTimeLevyArea,\n", + " None,\n", + " constant_controller,\n", + " bm_tol,\n", + " saveat,\n", + " use_progress_meter=False,\n", + ")\n", + "ys_t1 = batch_ys_opt[:, -1]\n", + "mean_t1 = jnp.mean(ys_t1, axis=0)\n", + "var_t1 = jnp.mean(ys_t1**2, axis=0) - mean_t1**2\n", + "print(f\"Stats at t=t1: mean={mean_t1}, var={var_t1}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a952251cf67716e8", + "metadata": {}, + "outputs": [], + "source": [ + "# simulate the optimal SDE\n", + "\n", + "bm = diffrax.VirtualBrownianTree(\n", + " t0, t1, bm_tol, bm_shape, jr.key(0), levy_area=SpaceTimeLevyArea\n", + ")\n", + "terms = get_terms(bm)\n", + "sol_optimal = diffrax.diffeqsolve(\n", + " terms,\n", + " SPaRK(),\n", + " t0,\n", + " t1,\n", + " None,\n", + " y0,\n", + " args_opt,\n", + " saveat=saveat,\n", + " stepsize_controller=constant_controller,\n", + ")\n", + "\n", + "fig, ax = plt.subplots(1, 1, figsize=(6, 4))\n", + "ax.plot(sol_optimal.ts, sol_optimal.ys[:, 0], label=\"Y_1\")\n", + "ax.plot(sol_optimal.ts, sol_optimal.ys[:, 1], label=\"Y_2\")\n", + "ax.plot(sol_optimal.ts, sol_optimal.ys[:, 2], label=\"Y_3\")\n", + "ax.set_title(\"SDE solution\")\n", + "ax.legend()\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "fa031b6ebb679028", + "metadata": {}, + "source": [ + "# Table of available SRK methods in Diffrax\n", + "\n", + "## Itô vs Stratonovich SDEs\n", + "Some of the solvers converge to the Itô solution of the SDE and others to the Stratonovich solution. The Itô and Stratonovich solutions coincide iff the SDE has additive noise (as defined below). For other SDEs it is possible to convert them between the Itô and Stratonovich versions using the Itô-Stratonovich correction term.\n", + "\n", + "\n", + "## Noise type\n", + "Depending on the type of noise (i.e. diffusion term) present in the SDE, different SRK methods have different strong orders of convergence. These types of noise are the same for Itô and Stratonovich SDEs.\n", + "\n", + "\n", + "### General noise\n", + "Any SDE of the form (as above):\n", + "$$\n", + " dY_t = f(Y_t, t) dt + g(Y_t, t) dW_t, \n", + "$$\n", + "where $t \\in [0, T]$, $Y_t \\in \\mathbb{R}^e$, and $W$ is a standard Brownian motion on $\\mathbb{R}^d$. We refer to $f: \\mathbb{R}^e \\times [0, T] \\to \\mathbb{R}^e$ as the drift vector field and $g: \\mathbb{R}^e \\times [0, T] \\to \\mathbb{R}^{e \\times d}$ is the diffusion matrix field with columns $g_i$ for $i = 1, \\ldots, d$.\n", + "\n", + "\n", + "### Commutative noise\n", + "We say that the diffusion is commutative when the columns of $g$ commute in the Lie bracket, that is\n", + "$$\n", + " \\frac{d}{dy} g_i(y, t) \\; g_j(y, t) = g_i(y, t) \\, \\frac{d}{dy} g_j(y, t) \\quad \\forall \\, y, t.\n", + "$$\n", + "For example, this holds when $g$ is diagonal or when the dimension of BM is $d=1$.\n", + "\n", + "\n", + "### Additive noise\n", + "We say that the diffusion is additive when $g$ does not depend on $Y_t$ and the SDE can be written as\n", + "$$\n", + " dY_t = f(Y_t, t) dt + g(t) dW_t.\n", + "$$\n", + "Additive noise is a special case of commutative noise. For additive noise SDEs, the Itô and Stratonovich solutions conicide. Some solvers (ShARK, SRA1, SEA) are specifically designed for additive noise SDEs, so for those we do not specify whether they are Itô or Stratonovich." + ] + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "```\n", + "+----------------+-------+------------+----------------------------------+-------------------+----------------+------------------------------------------+\n", + "| | SDE | Lévy | Strong/weak order per noise type | VF evaluations | Embedded error | Recommended for |\n", + "| | type | area +---------+-------------+----------+-------+-----------+ estimation | (and other notes) |\n", + "| | | | General | Commutative | Additive | Drift | Diffusion | | |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| Euler | Itô | BM only | 0.5/1.0 | 0.5/1.0 | 1.0/1.0 | 1 | 1 | No | Itô SDEs when a cheap solver is needed |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| Heun | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 2 | 2 | Yes | Stratonovich SDEs without space-time LA. |\n", + "| | | | | | | | | | Has weak order 2 for constant noise. |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| ItoMilstein | Itô | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 1 | 1 + deriv | No | Better than Euler for Itô SDEs, but |\n", + "| | | | | | | | ative | | comuptes the derivative of diffusion VF |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| Stratonovich | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 1 | 1 + deriv | No | For commutative Stratonovich SDEs when |\n", + "| Milstein | | | | | | | ative | | space-time Lévy area is not available. |\n", + "| | | | | | | | | | Computes derivative of diffusion VF. |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| ReversibleHeun | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 1 | 1 | Yes | When a reversible solver is needed, e.g. |\n", + "| | | | | | | | | | for Neural SDEs. This method is FSAL. |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| Midpoint | Strat | BM only | 0.5/1.0 | 1.0/1.0 | 1.0/1.0 | 2 | 2 | Yes | Usually Heun should be preferred. |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| ShARK | Addit | space-time | / | / | 1.5/2.0 | 2 | 2 | Yes | Additive noise SDEs |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| SRA1 | Addit | space-time | / | / | 1.5/2.0 | 2 | 2 | Yes | Only slightly worse than ShARK |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| SEA | Addit | space-time | / | / | 1.0/1.0 | 1 | 1 | No | Cheap solver for additive noise SDEs |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| SPaRK | Strat | space-time | 0.5/1.0 | 1.0/1.0 | 1.5/2.0 | 3 | 3 | Yes | General SDEs when embedded error |\n", + "| | | | | | | | | | estimation is needed |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| GeneralShARK | Strat | space-time | 0.5/1.0 | 1.0/1.0 | 1.5/2.0 | 2 | 3 | No | General SDEs when embedded error |\n", + "| | | | | | | | | | estimaiton is not needed |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "| SlowRK | Strat | space-time | 0.5/1.0 | 1.5/2.0 | 1.5/2.0 | 2 | 5 | No | Commutative noise SDEs. |\n", + "+----------------+-------+------------+---------+-------------+----------+-------+-----------+----------------+------------------------------------------+\n", + "```" + ], + "id": "4a49d8a22ccfcfe9" + }, + { + "cell_type": "markdown", + "id": "58c5d15087e0b321", + "metadata": {}, + "source": [ + "### Appendix: Using the `WeaklyDiagonalControlTerm`\n", + "\n", + "Suppose we instead have $e = d = 3$ and the diffusion matrix field always returns a diagonal matrix. That is, $\\forall t, y$, the matrix $g(y,t)$ is diagonal. Then we can use the `WeaklyDiagonalControlTerm`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d62ffefdc35dac76", + "metadata": {}, + "outputs": [], + "source": [ + "bm2 = diffrax.VirtualBrownianTree(\n", + " t0, t1, bm_tol, (3,), key, levy_area=SpaceTimeLevyArea\n", + ")\n", + "\n", + "\n", + "# Using the regular ControlTerm\n", + "def g_CT(t, y, args):\n", + " return jnp.array(\n", + " [[jnp.abs(jnp.sum(y)), 0.0, 0.0], [0.0, y[0], 0.0], [0.0, 0.0, 10 * t]],\n", + " dtype=y.dtype,\n", + " )\n", + "\n", + "\n", + "diffusion_term_CT = diffrax.ControlTerm(g_CT, bm2)\n", + "\n", + "\n", + "# The same SDE using the WeaklyDiagonalControlTerm\n", + "def g_WD(t, y, args):\n", + " return jnp.array([jnp.abs(jnp.sum(y)), y[0], 10 * t], dtype=y.dtype)\n", + "\n", + "\n", + "diffusion_term_WD = diffrax.WeaklyDiagonalControlTerm(g_WD, bm2)\n", + "\n", + "terms_CT = diffrax.MultiTerm(ode_term, diffusion_term_CT)\n", + "terms_WD = diffrax.MultiTerm(ode_term, diffusion_term_WD)\n", + "\n", + "sol_CT = diffrax.diffeqsolve(\n", + " terms_CT,\n", + " SPaRK(),\n", + " t0,\n", + " t1,\n", + " dt0,\n", + " y0,\n", + " args,\n", + " saveat=diffrax.SaveAt(steps=True),\n", + ")\n", + "sol_WD = diffrax.diffeqsolve(\n", + " terms_WD,\n", + " SPaRK(),\n", + " t0,\n", + " t1,\n", + " dt0,\n", + " y0,\n", + " args,\n", + " saveat=diffrax.SaveAt(steps=True),\n", + ")\n", + "assert jnp.allclose(sol_CT.ys, sol_WD.ys)\n", + "\n", + "# Plot the solution\n", + "fig, ax = plt.subplots(1, 1, figsize=(6, 4))\n", + "ax.plot(sol.ts, sol_WD.ys[:, 0], label=\"Y_1\")\n", + "ax.plot(sol.ts, sol_WD.ys[:, 1], label=\"Y_2\")\n", + "ax.plot(sol.ts, sol_WD.ys[:, 2], label=\"Y_3\")\n", + "ax.set_title(\"SDE solution\")\n", + "ax.legend()\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "62b40efc1e1b19fb", + "metadata": {}, + "source": "### Unstable SDE" + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "9df7b5e23f3162ce", + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-03T11:56:21.530191Z", + "start_time": "2024-06-03T11:56:10.490604Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Num steps: constant: 2000, adaptive: 1185\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhAAAAKqCAYAAABrUWeaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAD6VElEQVR4nOydd3gU1frHv7spm0KKQCCJhK4CCqgoVQUEDYqFKygoXkARyw+8esEC1wLYsIsFFe5FUEQpSlFUelHpVYoQAekhCS2FhCSb3fP7YzOzU860zW52N7yf58mT3dkzZ95pZ95527ExxhgIgiAIgiAsYA+2AARBEARBhB+kQBAEQRAEYRlSIAiCIAiCsAwpEARBEARBWIYUCIIgCIIgLEMKBEEQBEEQliEFgiAIgiAIy5ACQRAEQRCEZUiBIAiCIAjCMqRAEARB+IkhQ4agcePGwRaDIKoFUiAIwo8cPHgQjz32GJo2bYqYmBgkJiaiS5cu+PDDD3HhwoWgyvbzzz9j3Lhx1bKtdevWYdy4ccjPz/dbn2+88QYWLFjgt/4IgqgaNpoLgyD8w08//YR7770XDocDgwYNwlVXXYXy8nL8/vvv+P777zFkyBBMmTIlaPKNGDECkyZNQnXc8u+++y6effZZHDp0yG9v5LVq1UK/fv0wffp0v/QXCJxOJ9xuNxwOR7BFIYiAExlsAQiiJnDo0CEMGDAAjRo1wsqVK5GWlib+Nnz4cBw4cAA//fRTECUkqoOoqKhgi0AQ1QcjCKLKPP744wwAW7t2ran2TqeTvfLKK6xp06YsOjqaNWrUiI0ZM4aVlpbK2jVq1Ij17t2b/fbbb+z6669nDoeDNWnShH355ZeyduXl5WzcuHGsefPmzOFwsNq1a7MuXbqwpUuXMsYYGzx4MAOg+hN45513WKdOnVjt2rVZTEwMu/baa9ncuXNVcgNgw4cPZ/Pnz2dXXnkli46OZq1atWK//PKL2Gbs2LHcbR06dEjzePz111/snnvuYfXr12cOh4NdeumlrH///iw/P1/crvJv8ODB4vrHjx9nDz30EKtXr54o09SpU2XbWLVqFQPAZs2axcaMGcPq16/P4uLi2J133smOHj2qf8IYY4WFheypp55ijRo1YtHR0SwlJYX17NmTbd26VWwzePBg1qhRI/F7165dubIDYNOmTRPbnTt3jj311FOsQYMGLDo6mjVr1oy9+eabzOVyGcpFEMGCLBAE4Qd+/PFHNG3aFJ07dzbV/pFHHsGXX36Jfv36YdSoUdi4cSMmTJiAvXv3Yv78+bK2Bw4cQL9+/TB06FAMHjwYX3zxBYYMGYJ27drhyiuvBACMGzcOEyZMwCOPPIL27dujsLAQW7ZswbZt23DLLbfgscceQ3Z2NpYtW4YZM2ao5Pnwww9x1113YeDAgSgvL8esWbNw7733YtGiRejdu7es7e+//4558+bh//7v/5CQkICPPvoIffv2xdGjR1GnTh3cc889+Ouvv/Dtt9/igw8+QN26dQEAKSkp3GNRXl6OzMxMlJWV4cknn0RqaipOnDiBRYsWIT8/H0lJSZgxY4a4b48++igAoFmzZgCA3NxcdOzYETabDSNGjEBKSgp++eUXDB06FIWFhXj66adl23v99ddhs9nw/PPPIy8vDxMnTkTPnj2xY8cOxMbGap6zxx9/HN999x1GjBiBVq1a4cyZM/j999+xd+9eXHvttdx1XnjhBTzyyCOyZV9//TWWLFmCevXqAQBKSkrQtWtXnDhxAo899hgaNmyIdevWYcyYMTh58iQmTpyoKRNBBJVgazAEEe4UFBQwAOzuu+821X7Hjh0MAHvkkUdky5955hkGgK1cuVJc1qhRIwaA/frrr+KyvLw85nA42KhRo8Rlbdu2Zb1799bd7vDhw5nWLV9SUiL7Xl5ezq666ip28803y5YDYNHR0ezAgQPisj/++IMBYB9//LG47J133jG0Oghs376dAeBaPKTEx8fLrA4CQ4cOZWlpaez06dOy5QMGDGBJSUnivgkWiEsvvZQVFhaK7ebMmcMAsA8//FB3+0lJSWz48OG6bZQWCCVr165lUVFR7OGHHxaXvfrqqyw+Pp799ddfsrajR49mERERpqwjBBEMKAuDIKpIYWEhACAhIcFU+59//hkAMHLkSNnyUaNGAYAqVqJVq1a48cYbxe8pKSm44oor8Pfff4vLkpOTsWfPHuzfv9/6DgCyN+9z586hoKAAN954I7Zt26Zq27NnT/HtHwDatGmDxMREmTxWSEpKAgAsWbIEJSUlltZljOH777/HnXfeCcYYTp8+Lf5lZmaioKBAtQ+DBg2Snat+/fohLS1NPC9aJCcnY+PGjcjOzrYko0BOTg769euHq6++Gp9++qm4fO7cubjxxhtxySWXyOTv2bMnXC4Xfv31V5+2RxCBpkYrEL/++ivuvPNOpKenw2azWU4BW716Ne6++26kpaUhPj4eV199NWbOnClr061bN9hsNtWfYPZ1Op14/vnn0bp1a8THxyM9PR2DBg3yeRAiQo/ExEQAQFFRkan2R44cgd1uR/PmzWXLU1NTkZycjCNHjsiWN2zYUNXHJZdcgnPnzonfX3nlFeTn5+Pyyy9H69at8eyzz2Lnzp2m92HRokXo2LEjYmJiULt2baSkpOCzzz5DQUGBqq0ZeazQpEkTjBw5Ev/73/9Qt25dZGZmYtKkSdxtKzl16hTy8/MxZcoUpKSkyP4eeughAEBeXp5sncsuu0z23WazoXnz5jh8+LDutt5++23s3r0bGRkZaN++PcaNG2daaaqoqMB9990Hl8uFefPmybI09u/fj8WLF6vk79mzJ1d+gggVarQCUVxcjLZt22LSpEk+rb9u3Tq0adMG33//PXbu3ImHHnoIgwYNwqJFi8Q28+bNw8mTJ8W/3bt3IyIiAvfeey8Aj39z27ZteOmll7Bt2zbMmzcPWVlZuOuuu/yyj0TwSUxMRHp6Onbv3m1pPZvNZqpdREQEdzmTpGPedNNNOHjwIL744gtcddVV+N///odrr70W//vf/wz7/+2333DXXXchJiYGn376KX7++WcsW7YMDzzwADfl04w8Vnnvvfewc+dO/Oc//8GFCxfwr3/9C1deeSWOHz+uu57b7QYAPPjgg1i2bBn3r0uXLj7LJeW+++7D33//jY8//hjp6el45513cOWVV+KXX34xXPfZZ5/F+vXrMWfOHDRo0EC1D7fccoum/H379vWL/AThd4LqQKlGALD58+fLlpWWlrJRo0ax9PR0FhcXx9q3b89WrVql28/tt9/OHnroIc3fP/jgA5aQkMDOnz+v2WbTpk0MADty5IiVXSBCmEcffZQBYOvWrTNs+8YbbzAA7M8//5Qtz8nJYQBksQ1CFoaSrl27sq5du2puo6ioiF1zzTXs0ksvFZeNGDGCGwPx1FNPsdjYWFUGyAMPPKBqj8osDCWNGjWSxSe8++67pmMgeKxdu5YBYC+88IK4rFatWqoYiIqKCpaQkMDuv/9+wz6FGIgxY8bIlrvdbpaWlsYyMzMtyZibm8suvfRS1qVLF3EZLwbi22+/ZQDYxIkTuf20atWKderUydK2CSIUqNEWCCNGjBiB9evXY9asWdi5cyfuvfde9OrVS9ePXFBQgNq1a2v+PnXqVAwYMADx8fG6fdhsNiQnJ1dFfCKEeO655xAfH49HHnkEubm5qt8PHjyIDz/8EABw++23A4Aquv79998HAFXWgxnOnDkj+16rVi00b94cZWVl4jLhmlRWh4yIiIDNZoPL5RKXHT58uEpVH7W2xaOwsBAVFRWyZa1bt4bdblfJz5O9b9+++P7777kWoFOnTqmWffXVVzJ303fffYeTJ0/itttu05TR5XKpXCr16tVDenq6TEYlu3fvxiOPPIIHH3wQTz31FLfNfffdh/Xr12PJkiWq3/Lz81XHhiBChYs2jfPo0aOYNm0ajh49ivT0dADAM888g8WLF2PatGl44403VOvMmTMHmzdvxuTJk7l9btq0Cbt378bUqVM1t1taWornn38e999/v+g7J8KfZs2a4ZtvvkH//v3RsmVLWSXKdevWYe7cuRgyZAgAoG3bthg8eDCmTJmC/Px8dO3aFZs2bcKXX36JPn36oHv37pa336pVK3Tr1g3t2rVD7dq1sWXLFjHlUKBdu3YAgH/961/IzMxEREQEBgwYgN69e+P9999Hr1698MADDyAvLw+TJk1C8+bNLcVRSBG29cILL2DAgAGIiorCnXfeyVWsV65ciREjRuDee+/F5ZdfjoqKCsyYMUNUDqR9Ll++HO+//z7S09PRpEkTdOjQAW+++SZWrVqFDh06YNiwYWjVqhXOnj2Lbdu2Yfny5Th79qxse7Vr18YNN9yAhx56CLm5uZg4cSKaN2+OYcOGae5PUVERGjRogH79+qFt27aoVasWli9fjs2bN+O9997TXE+Iw7jpppvw9ddfy37r3LkzmjZtimeffRY//PAD7rjjDjE9t7i4GLt27cJ3332Hw4cPi6mwBBFSBNsEUl1A4cJYtGgRA8Di4+Nlf5GRkey+++5Trb9y5UoWFxenKuAj5dFHH2WtW7fW/L28vJzdeeed7JprrmEFBQVV2h8iNPnrr7/YsGHDWOPGjVl0dDRLSEhgXbp0YR9//LHMReB0Otn48eNZkyZNWFRUFMvIyNAtJKVE6cJ47bXXWPv27VlycjKLjY1lLVq0YK+//jorLy8X21RUVLAnn3ySpaSkMJvNJnNPTJ06lV122WXM4XCwFi1asGnTpokFoaTApAuDMU964qWXXsrsdruuO+Pvv/9mDz/8MGvWrBmLiYlhtWvXZt27d2fLly+Xtdu3bx+76aabWGxsrKqQVG5uLhs+fDjLyMhgUVFRLDU1lfXo0YNNmTJFbCO4ML799ls2ZswYVq9ePRYbG8t69+5t6E4sKytjzz77LGvbti1LSEhg8fHxrG3btuzTTz+VtVO6MIQ0XN6ftJBUUVERGzNmDGvevDmLjo5mdevWZZ07d2bvvvuu7BwSRChx0cyFYbPZMH/+fPTp0wcAMHv2bAwcOBB79uxRBYXVqlULqamp4vc1a9aIb2lCERslxcXFSE9PxyuvvMI1VTqdTjEIa+XKlahTp47/do4gCENWr16N7t27Y+7cuejXr1+wxSGIsOeidWFcc801cLlcyMvLk+XYK1m9ejXuuOMOvPXWW5rKA+DJ5S4rK8ODDz6o+k1QHvbv349Vq1aR8kAQBEGEPTVagTh//jwOHDggfj906BB27NiB2rVr4/LLL8fAgQMxaNAgvPfee7jmmmtw6tQprFixAm3atEHv3r2xatUq3HHHHXjqqafQt29f5OTkAACio6NVgZRTp05Fnz59VMqB0+lEv379sG3bNixatAgul0vsp3bt2oiOjg7wUSAIgiCIABBsH0ogEXyeyj/Bd1peXs5efvll1rhxYxYVFcXS0tLYP/7xD7Zz507GmPYERMr0uX379jEA4sRFUg4dOqTpAzVKGSUIwn8I44FRyWyCIMxx0cRAEARBEAThPy7qOhAEQRAEQfgGKRAEQRAEQVimxgVRut1uZGdnIyEhwfRcAwRBEARBeOa0KSoqQnp6Oux2fRtDjVMgsrOzkZGREWwxCIIgCCJsOXbsmGriNyU1ToFISEgA4Nl5KhVNEARBEOYpLCxERkaG+CzVo8YpEILbIjExkRQIgiAIgvABMyEAFERJEARBEIRlSIEgCIIgCMIypEAQBEEQBGEZUiAIgiAIgrAMKRAEQRAEQViGFAiCIAiCICxTLQrEpEmT0LhxY8TExKBDhw7YtGmTbvu5c+eiRYsWiImJQevWrfHzzz9Xh5gEQRAEQZgk4ArE7NmzMXLkSIwdOxbbtm1D27ZtkZmZiby8PG77devW4f7778fQoUOxfft29OnTB3369MHu3bsDLSpBEARBECYJ+HTeHTp0wPXXX49PPvkEgGeuioyMDDz55JMYPXq0qn3//v1RXFyMRYsWics6duyIq6++Gp9//rnh9goLC5GUlISCggK/FZIqr3DDTbOeEwRBECGEI9Lu9zmfrDxDA1qJsry8HFu3bsWYMWPEZXa7HT179sT69eu566xfvx4jR46ULcvMzMSCBQu47cvKylBWViZ+LywsrLrgCt78ZR++WHvI7/0SBEEQhK+0zUjG/Cc6w24PzsSRAXVhnD59Gi6XC/Xr15ctr1+/PnJycrjr5OTkWGo/YcIEJCUliX80kRZBEARxMfDHsXwUXHAGbfthPxfGmDFjZBYLYSIQf/L8bVdg1K2X+7VPgiAIgvAFF2NoM24pACCYzvWAKhB169ZFREQEcnNzZctzc3ORmprKXSc1NdVSe4fDAYfD4R+BNXBERsAR9qoWQRAEURNwu71qQ4DDGHUJqAsjOjoa7dq1w4oVK8RlbrcbK1asQKdOnbjrdOrUSdYeAJYtW6bZniAIgiAuJqRxkzXWAgEAI0eOxODBg3Hdddehffv2mDhxIoqLi/HQQw8BAAYNGoRLL70UEyZMAAA89dRT6Nq1K9577z307t0bs2bNwpYtWzBlypRAi0oQBEEQIY808yKYCYIBVyD69++PU6dO4eWXX0ZOTg6uvvpqLF68WAyUPHr0KOx2ryGkc+fO+Oabb/Diiy/iP//5Dy677DIsWLAAV111VaBFJQiCIIiwggXRBhHwOhDVTSDqQBAEQRBEKNF0zE9wM2Djf3qgfmKM3/q18gyluTAIgiAIIswQ3BjBNAGQAkEQBEEQYYYQBRFMFwYpEARBEAQRZghxlGSBIAiCIAjCNLZKG0QwgxhJgSAIgiCIMEOwQEiLSlU3pEAQBEEQRJjh50k4fYIUCIIgCIIIM0QXBsVAEARBEARhFjGIkrIwCIIgCIIwi5jGSRYIgiAIgiDMYq80Qbhr6mycBEEQBEEEANGFETxIgSAIgiCIMINcGARBEARBWMY7pTe5MAiCIAiCMAmVsiYIgiAIwjJCECXFQBAEQRAEYRrBgUFZGARBEARBmIZcGARBEARB+ACVsiYIgiAIwiJUypogCIIgCMvU+DoQZ8+excCBA5GYmIjk5GQMHToU58+f112nW7dusNlssr/HH388kGISBEEQRFghZmEEUYGIDGTnAwcOxMmTJ7Fs2TI4nU489NBDePTRR/HNN9/orjds2DC88sor4ve4uLhAikkQBEEQYUUouDACpkDs3bsXixcvxubNm3HdddcBAD7++GPcfvvtePfdd5Genq65blxcHFJTUwMlGkEQBEGENTXahbF+/XokJyeLygMA9OzZE3a7HRs3btRdd+bMmahbty6uuuoqjBkzBiUlJZpty8rKUFhYKPsjCIIgiJqMLQQKSQXMApGTk4N69erJNxYZidq1ayMnJ0dzvQceeACNGjVCeno6du7cieeffx5ZWVmYN28et/2ECRMwfvx4v8pOEARBEOEAC6IJwrICMXr0aLz11lu6bfbu3euzQI8++qj4uXXr1khLS0OPHj1w8OBBNGvWTNV+zJgxGDlypPi9sLAQGRkZPm+fIAiCIEIde6X/wB1OQZSjRo3CkCFDdNs0bdoUqampyMvLky2vqKjA2bNnLcU3dOjQAQBw4MABrgLhcDjgcDhM90cQBEEQ4Y4NwZ+N07ICkZKSgpSUFMN2nTp1Qn5+PrZu3Yp27doBAFauXAm32y0qBWbYsWMHACAtLc2qqARBEARRI6nRpaxbtmyJXr16YdiwYdi0aRPWrl2LESNGYMCAAWIGxokTJ9CiRQts2rQJAHDw4EG8+uqr2Lp1Kw4fPowffvgBgwYNwk033YQ2bdoESlSCIAiCCCuCb38IcCGpmTNnokWLFujRowduv/123HDDDZgyZYr4u9PpRFZWlphlER0djeXLl+PWW29FixYtMGrUKPTt2xc//vhjIMUkCIIgiLDCVtMLSdWuXVu3aFTjxo1lEaQZGRlYs2ZNIEUiCIIgiLBHcGHQdN4EQRAEQZimRheSIgiCIAgiMHgLSZEFgiAIgiAIkwgWiGBGUZICQRAEQRBhhncyreBBCgRBEARBhBmhMJ03KRAEQRAEEaZQFgZBEARBEKYJhdk4SYEgCIIgiDDDm8ZJFgiCIAiCIExCQZQEQRAEQVjGHgIaBCkQBEEQBBFmUClrgiAIgiAsQ6WsCYIgCIKwDmVhEARBEARhFcrCIAiCIAjCMvbgx1CSAkEQBEEQ4YZYSIosEARBEARBmIWCKAmCIAiCsEwIlIEgBYIgCIIgwg0baDZOgiAIgiAs4rVA1MAYiNdffx2dO3dGXFwckpOTTa3DGMPLL7+MtLQ0xMbGomfPnti/f3+gRCQIgiCIsERUIGqiBaK8vBz33nsvnnjiCdPrvP322/joo4/w+eefY+PGjYiPj0dmZiZKS0sDJSZBEARBhB2CCyOYpawjA9Xx+PHjAQDTp0831Z4xhokTJ+LFF1/E3XffDQD46quvUL9+fSxYsAADBgwIlKgEQRAEEVYIFohgEjIxEIcOHUJOTg569uwpLktKSkKHDh2wfv16zfXKyspQWFgo+yMIgiCImkyNdmFYJScnBwBQv3592fL69euLv/GYMGECkpKSxL+MjIyAykkQBEEQwUbMwgiXIMrRo0fDZrPp/u3bty9QsnIZM2YMCgoKxL9jx45V6/YJgiAIoroJBQuEpRiIUaNGYciQIbptmjZt6pMgqampAIDc3FykpaWJy3Nzc3H11VdrrudwOOBwOHzaJkEQBEGEI0Ipa3e4KBApKSlISUkJiCBNmjRBamoqVqxYISoMhYWF2Lhxo6VMDoIgCIKo6dTo2TiPHj2KHTt24OjRo3C5XNixYwd27NiB8+fPi21atGiB+fPnA/BoU08//TRee+01/PDDD9i1axcGDRqE9PR09OnTJ1BiEgRBEETYEQqlrAOWxvnyyy/jyy+/FL9fc801AIBVq1ahW7duAICsrCwUFBSIbZ577jkUFxfj0UcfRX5+Pm644QYsXrwYMTExgRKTIAiCIMIOMYsziBqEjQXT/hEACgsLkZSUhIKCAiQmJgZbHIIgCILwO0Onb8aKfXl4q29r9L++od/6tfIMDZk0ToIgCIIgzCG4MIIZREkKBEEQBEGEHTQbJ0EQBEEQFqnRs3ESBEEQBBEYvGmcwZOBFAiCIAiCCDPsNqGUdRBlCOK2CYIgCILwAW8pa3JhEARBEARhklCYC4MUCIIgCIIIM8TZOMkCQRAEQRCEaUKglDUpEARBEAQRZohBlOTCIAiCIAjCLGIaZxBlIAWCIAiCIMIMysIgCIIgCMIyVEiKIAiCIAjL2MRCUmSBIAiCIAjCJFQHgiAIgiAIy4h1IIIoAykQBEEQBBFmCBYINwVREgRBEARhFgqiJAiCIAjCMoIFIpiQAkEQBEEQYYa3EmUNdGG8/vrr6Ny5M+Li4pCcnGxqnSFDhsBms8n+evXqFSgRCYIgCCIsCYUsjMhAdVxeXo57770XnTp1wtSpU02v16tXL0ybNk387nA4AiEeQRAEQYQxHg3CXRMViPHjxwMApk+fbmk9h8OB1NTUAEhEEARBEDUD0QJBhaS8rF69GvXq1cMVV1yBJ554AmfOnNFtX1ZWhsLCQtkfQRAEQdRkKAtDQa9evfDVV19hxYoVeOutt7BmzRrcdtttcLlcmutMmDABSUlJ4l9GRkY1SkwQBEEQ1Y8YRBlMGaw0Hj16tCrIUfm3b98+n4UZMGAA7rrrLrRu3Rp9+vTBokWLsHnzZqxevVpznTFjxqCgoED8O3bsmM/bJwiCIIhwwBYCJghLMRCjRo3CkCFDdNs0bdq0KvKo+qpbty4OHDiAHj16cNs4HA4KtCQIgiAuKkT9IYgyWFIgUlJSkJKSEihZVBw/fhxnzpxBWlpatW2TIAiCIEIdYTbOGlnK+ujRo9ixYweOHj0Kl8uFHTt2YMeOHTh//rzYpkWLFpg/fz4A4Pz583j22WexYcMGHD58GCtWrMDdd9+N5s2bIzMzM1BiEgRBEETYUiPrQLz88sv48ssvxe/XXHMNAGDVqlXo1q0bACArKwsFBQUAgIiICOzcuRNffvkl8vPzkZ6ejltvvRWvvvoquSgIgiAIQoI3jTN4BEyBmD59umENCGkJztjYWCxZsiRQ4hAEQRBEjcFbyjqIMgRv0wRBEARB+II3iLIGxkAQBEEQBBEYQmEuDFIgCIIgCCLMsNXk2TgJgiAIgggMIVBHihQIgiAIggg3bOFWypogCIIgiOBDMRAEQRAEQVhGcGHUyEqUBEEQBEEEBnEyrSBCCgRBEARBhBk2UBYGQRAEQRAWsYdAKWtSIAiCIAgi3KBS1gRBEARBWIWCKAmCIAiCsEwozMZJCgRBEARBhBneIMrgyUAKBEEQBEGEGXYxjZNcGARBEARBmIQqURIEQRAEYRkbZWEQBEEQBOErlIVBEARBEIRpKAuDIAiCIAjL2GuyC+Pw4cMYOnQomjRpgtjYWDRr1gxjx45FeXm57nqlpaUYPnw46tSpg1q1aqFv377Izc0NlJgEQRAEEXYISRisJmZh7Nu3D263G5MnT8aePXvwwQcf4PPPP8d//vMf3fX+/e9/48cff8TcuXOxZs0aZGdn45577gmUmARBEAQRdti8GkTQiAxUx7169UKvXr3E702bNkVWVhY+++wzvPvuu9x1CgoKMHXqVHzzzTe4+eabAQDTpk1Dy5YtsWHDBnTs2DFQ4hIEQRBE2CAUkrpogigLCgpQu3Ztzd+3bt0Kp9OJnj17istatGiBhg0bYv369dUhIkEQBEGEPKEQRBkwC4SSAwcO4OOPP9a0PgBATk4OoqOjkZycLFtev3595OTkcNcpKytDWVmZ+L2wsNAv8hIEQRBEqBKWdSBGjx4Nm82m+7dv3z7ZOidOnECvXr1w7733YtiwYX4THgAmTJiApKQk8S8jI8Ov/RMEQRBEqBECIRDWLRCjRo3CkCFDdNs0bdpU/JydnY3u3bujc+fOmDJliu56qampKC8vR35+vswKkZubi9TUVO46Y8aMwciRI8XvhYWFpEQQBEEQNRpvKevgqRCWFYiUlBSkpKSYanvixAl0794d7dq1w7Rp02C36xs82rVrh6ioKKxYsQJ9+/YFAGRlZeHo0aPo1KkTdx2HwwGHw2FtJwiCIAgijBEtEOHkwjDLiRMn0K1bNzRs2BDvvvsuTp06hZycHFksw4kTJ9CiRQts2rQJAJCUlIShQ4di5MiRWLVqFbZu3YqHHnoInTp1ogwMgiAIgqhEjIEIohMjYEGUy5Ytw4EDB3DgwAE0aNBA9ptgcnE6ncjKykJJSYn42wcffAC73Y6+ffuirKwMmZmZ+PTTTwMlJkEQBEGEHfYQmI3TxoLpQAkAhYWFSEpKQkFBARITE4MtDkEQBEH4nRkbjuClBbvR68pUfP7Pdn7r18ozlObCIAiCIIgwo0aXsiYIgiAIIjAIWRjumhhESRAEQRBEYBBKWdfILAyCIAiCIAKDOJkWuTAIgiAIgjBLKGRhkAJBEARBEGGG6MIIogykQBAEQRBEuBECpaxJgSAIgiCIMEMIgaAsDIIgCIIgTOMtZR08SIEgCIIgiDDDTi4MgiAIgiCs4k3jDB6kQBAEQRBEmEGFpAiCIAiCsIy3lDW5MAiCIAiCsAhZIAiCIAiCMI1dzMIgCwRBEARBECaxUSlrgiAIgiCsQqWsCYIgCIKwjI3qQBAEQRAEYRWhDAS5MAiCIAiCME2NLmV9+PBhDB06FE2aNEFsbCyaNWuGsWPHory8XHe9bt26wWazyf4ef/zxQIlJEARBEGFHKLgwIgPV8b59++B2uzF58mQ0b94cu3fvxrBhw1BcXIx3331Xd91hw4bhlVdeEb/HxcUFSkyCIAiCCDtEF0YQZQiYAtGrVy/06tVL/N60aVNkZWXhs88+M1Qg4uLikJqaGijRCIIgCCKsEV0YF0sMREFBAWrXrm3YbubMmahbty6uuuoqjBkzBiUlJdUgHUEQBEGEB94gyhrowlBy4MABfPzxx4bWhwceeACNGjVCeno6du7cieeffx5ZWVmYN28et31ZWRnKysrE74WFhX6VmyAIgiBCDXvl639YuTBGjx6Nt956S7fN3r170aJFC/H7iRMn0KtXL9x7770YNmyY7rqPPvqo+Ll169ZIS0tDjx49cPDgQTRr1kzVfsKECRg/frzFvSAIgiCI8CUUZuO0MYv2j1OnTuHMmTO6bZo2bYro6GgAQHZ2Nrp164aOHTti+vTpsNuteU2Ki4tRq1YtLF68GJmZmarfeRaIjIwMFBQUIDEx0dK2CIIgCCIcWJWVh4embcZVlyZi0ZM3+q3fwsJCJCUlmXqGWrZApKSkICUlxVTbEydOoHv37mjXrh2mTZtmWXkAgB07dgAA0tLSuL87HA44HA7L/RIEQRBEuFKjC0mdOHEC3bp1Q8OGDfHuu+/i1KlTyMnJQU5OjqxNixYtsGnTJgDAwYMH8eqrr2Lr1q04fPgwfvjhBwwaNAg33XQT2rRpEyhRCYIgCCKsELIw3EFUIAIWRLls2TIcOHAABw4cQIMGDWS/CV4Tp9OJrKwsMcsiOjoay5cvx8SJE1FcXIyMjAz07dsXL774YqDEJAiCIIiww16TC0kNGTIEQ4YM0W3TuHFj2c5nZGRgzZo1gRKJIAiCIGoENtGJETxoLgyCIAiCCDO8payDJwMpEARBEAQRZnhLWdN03gRBEARBmKVSgwhmECUpEARBEAQRZngLSZEFgiAIgiAIk4hZGMGUIYjbJgiCIAjCB2y24GsQpEAQBEEQRJgRAvoDKRAEQRAEEW4IWRhuioEgCIIgCMIsVAeCIAiCIAjLCDEQVAeCIAiCIAjT1OjZOAmCIAiCCAyiBYIUCIIgCIIgzBL8qbRIgSAIgiCIsMMmlrKmGAiCIAiCIExiJxcGQRAEQRC+QlkYBEEQBEGYhupAEARBEARhGXE2ziDKQAoEQRAEQYQZXgsEuTAIgiAIgjBJjQ+ivOuuu9CwYUPExMQgLS0N//znP5Gdna27TmlpKYYPH446deqgVq1a6Nu3L3JzcwMpJkEQBEGEFTV+Ns7u3btjzpw5yMrKwvfff4+DBw+iX79+uuv8+9//xo8//oi5c+dizZo1yM7Oxj333BNIMQmCIAgirPCWsg6eChEZyM7//e9/i58bNWqE0aNHo0+fPnA6nYiKilK1LygowNSpU/HNN9/g5ptvBgBMmzYNLVu2xIYNG9CxY8dAiksQBEEQYUGNt0BIOXv2LGbOnInOnTtzlQcA2Lp1K5xOJ3r27Ckua9GiBRo2bIj169dXl6gEQRAEEeJ4NAi3uwYHUT7//POIj49HnTp1cPToUSxcuFCzbU5ODqKjo5GcnCxbXr9+feTk5HDXKSsrQ2FhoeyPIAiCIGoy9nC0QIwePRo2m033b9++fWL7Z599Ftu3b8fSpUsRERGBQYMG+dVnM2HCBCQlJYl/GRkZfuubIAiCIEIRWwj4MCzHQIwaNQpDhgzRbdO0aVPxc926dVG3bl1cfvnlaNmyJTIyMrBhwwZ06tRJtV5qairKy8uRn58vs0Lk5uYiNTWVu60xY8Zg5MiR4vfCwkJSIgiCIIgajRhEGUQZLCsQKSkpSElJ8WljbrcbgMftwKNdu3aIiorCihUr0LdvXwBAVlYWjh49ylU4AMDhcMDhcPgkD0EQBEGEI6FQSCpgWRgbN27E5s2bccMNN+CSSy7BwYMH8dJLL6FZs2aiMnDixAn06NEDX331Fdq3b4+kpCQMHToUI0eORO3atZGYmIgnn3wSnTp1ogwMgiAIgqgkFEpZB0yBiIuLw7x58zB27FgUFxcjLS0NvXr1wosvvihaDJxOJ7KyslBSUiKu98EHH8But6Nv374oKytDZmYmPv3000CJSRAEQRBhh2CBcAfRAmFjwbR/BIDCwkIkJSWhoKAAiYmJwRaHIAiCIPzO8XMluOGtVXBE2pH12m1+69fKM5TmwiAIgiCIMEPIwgirNE6CIAiCIIKLkIURTA2CFAiCIAiCCDO8ZSBqcCVKgiAIgiD8izCddxArWZMCQRAEQRDhRijMxkkKBEEQBEGEG8GvZE0KBEEQBEGEG2IhKXJhEARBEARhFpvN+zlYbgxSIAiCIAgizJDoD0GzQpACQRAEQRBhhl1iggiWF4MUCIIgCIIIM8iFQRAEQRCEZWwgCwRBEARBEFaRWSCCIwIpEARBEAQRZkhdGMGa0psUCIIgCIIIM6RBlEGTIdgCEARBEARhDUrjJAiCIAjCMrIsjCCFUZICQRAEQRBhhiwLgywQBEEQBEGYQW6BCA6kQBAEQRBEmFHjszDuuusuNGzYEDExMUhLS8M///lPZGdn667TrVs32Gw22d/jjz8eSDEJgiAIIqyo8S6M7t27Y86cOcjKysL333+PgwcPol+/fobrDRs2DCdPnhT/3n777UCKSRAEQRBhhSyLM0gKRGQgO//3v/8tfm7UqBFGjx6NPn36wOl0IioqSnO9uLg4pKamBlI0giAIgghb5PpDDXRhSDl79ixmzpyJzp076yoPADBz5kzUrVsXV111FcaMGYOSkpJqkpIgCIIgQh+bLfgujIBaIADg+eefxyeffIKSkhJ07NgRixYt0m3/wAMPoFGjRkhPT8fOnTvx/PPPIysrC/PmzeO2LysrQ1lZmfi9sLDQr/ITBEEQRKhhD8cgytGjR6uCHJV/+/btE9s/++yz2L59O5YuXYqIiAgMGjRId+rRRx99FJmZmWjdujUGDhyIr776CvPnz8fBgwe57SdMmICkpCTxLyMjw+ouEQRBEERYIbNABEsGZnEi8VOnTuHMmTO6bZo2bYro6GjV8uPHjyMjIwPr1q1Dp06dTG2vuLgYtWrVwuLFi5GZman6nWeByMjIQEFBARITE01tgyAIgiDCjcajfwIAbH6hJ1ISHH7ps7CwEElJSaaeoZZdGCkpKUhJSfFJMLfbDQCyB74RO3bsAACkpaVxf3c4HHA4/HPgCIIgCCJcsNk88Q81Lohy48aN+OSTT7Bjxw4cOXIEK1euxP33349mzZqJ1ocTJ06gRYsW2LRpEwDg4MGDePXVV7F161YcPnwYP/zwAwYNGoSbbroJbdq0CZSoBEEQBBF2iE6MmlYHIi4uDvPmzUOPHj1wxRVXYOjQoWjTpg3WrFkjWgycTieysrLELIvo6GgsX74ct956K1q0aIFRo0ahb9+++PHHHwMlJkEQBEGEJcKU3u6aloXRunVrrFy5UrdN48aNZQGVGRkZWLNmTaBEIgiCIIgagxBHWeNcGARBEARBBA6hnHWNLGVNEARBEESAEC0QwYEUCIIgCIIIQ4QgSovVGPwGKRAEQRAEEYYIQZTkwiAIgiAIwjRiECUpEARBEARBmEV0YVAWBkEQBEEQZrGRC4MgCIIgCKt4LRDBgRQIgiAIgghDvDEQ5MIgCIIgCMIktppayjrUcblccDqdwRaDsEBUVBQiIiKCLQZBEERIYAvybFoXnQLBGENOTg7y8/ODLQrhA8nJyUhNTRU1b4IgiIsVbyGp4Gz/olMgBOWhXr16iIuLowdRmMAYQ0lJCfLy8gAAaWlpQZaIIAgiuIhZGEHa/kWlQLhcLlF5qFOnTrDFISwSGxsLAMjLy0O9evXInUEQxEVNsC0QF1UQpRDzEBcXF2RJCF8Rzh3FrxAEcbHjDaKkLIxqg9wW4QudO4IgCA9UypogCIIgCMtQKWuCIAiCICxDFgjCNDk5OXjyySfRtGlTOBwOZGRk4M4778SKFSuqZftDhgxBnz59AtJ3t27d8PTTTwekb4IgiJqIDcF16V5UWRjhzOHDh9GlSxckJyfjnXfeQevWreF0OrFkyRIMHz4c+/btC7aIBEEQRDVir9QfKIiS0OX//u//YLPZsGnTJvTt2xeXX345rrzySowcORIbNmwAABw9ehR33303atWqhcTERNx3333Izc0V+xg3bhyuvvpqzJgxA40bN0ZSUhIGDBiAoqIisc13332H1q1bIzY2FnXq1EHPnj1RXFyMcePG4csvv8TChQths9lgs9mwevVqAMDzzz+Pyy+/HHFxcWjatCleeuklWZaE0XaHDBmCNWvW4MMPPxT7Pnz4cOAPKkEQRBhzUczGWVZWhquvvho2mw07duzQbVtaWorhw4ejTp06qFWrFvr27St7CPobxhhKyiuC8md2ApSzZ89i8eLFGD58OOLj41W/Jycnw+124+6778bZs2exZs0aLFu2DH///Tf69+8va3vw4EEsWLAAixYtwqJFi7BmzRq8+eabAICTJ0/i/vvvx8MPP4y9e/di9erVuOeee8AYwzPPPIP77rsPvXr1wsmTJ3Hy5El07twZAJCQkIDp06fjzz//xIcffoj//ve/+OCDD0xv98MPP0SnTp0wbNgwse+MjAzL55IgCOJipEYXknruueeQnp6OP/74w7Dtv//9b/z000+YO3cukpKSMGLECNxzzz1Yu3ZtQGS74HSh1ctLAtK3EX++kom4aONTcODAATDG0KJFC802K1aswK5du3Do0CHx4fvVV1/hyiuvxObNm3H99dcDANxuN6ZPn46EhAQAwD//+U+sWLECr7/+Ok6ePImKigrcc889aNSoEQCgdevW4jZiY2NRVlaG1NRU2bZffPFF8XPjxo3xzDPPYNasWXjuuefE5XrbTUpKQnR0NOLi4lR9EwRBEHxq/Gycv/zyC5YuXYp3333XsG1BQQGmTp2K999/HzfffDPatWuHadOmYd26daKZ/mLEzMWxd+9eZGRkyN7cW7VqheTkZOzdu1dc1rhxY/EhDnhKQgvlodu2bYsePXqgdevWuPfee/Hf//4X586dM9z27Nmz0aVLF6SmpqJWrVp48cUXcfToUVkbve0SBEEQ1hEViCBtP6AWiNzcXAwbNgwLFiwwVf1x69atcDqd6Nmzp7isRYsWaNiwIdavX4+OHTuq1ikrK0NZWZn4vbCw0JKMsVER+POVTEvr+IvYKHOlmC+77DLYbDa/BEpGRUXJvttsNrjdbgBAREQEli1bhnXr1mHp0qX4+OOP8cILL2Djxo1o0qQJt7/169dj4MCBGD9+PDIzM5GUlIRZs2bhvffeM71dgiAIwjr2mhoDwRjDkCFD8Pjjj+O6664ztU5OTg6io6ORnJwsW16/fn3k5ORw15kwYQKSkpLEP6u+c5vNhrjoyKD8ma2qWLt2bWRmZmLSpEkoLi5W/Z6fn4+WLVvi2LFjOHbsmLj8zz//RH5+Plq1amXpeHTp0gXjx4/H9u3bER0djfnz5wMAoqOj4XK5ZO3XrVuHRo0a4YUXXsB1112Hyy67DEeOHDG9PQFe3wRBEIQ23rkwwsSFMXr0aDFSXutv3759+Pjjj1FUVIQxY8YEQm6RMWPGoKCgQPyTPkBrEpMmTYLL5UL79u3x/fffY//+/di7dy8++ugjdOrUCT179kTr1q0xcOBAbNu2DZs2bcKgQYPQtWtX0wrcxo0b8cYbb2DLli04evQo5s2bh1OnTqFly5YAPG6InTt3IisrC6dPn4bT6cRll12Go0ePYtasWTh48CA++ugjUeGwQuPGjbFx40YcPnwYp0+fJusEQRCEAcGejdOyAjFq1Cjs3btX969p06ZYuXIl1q9fD4fDgcjISDRv3hwAcN1112Hw4MHcvlNTU1FeXo78/HzZ8tzcXM3gOofDgcTERNlfTaRp06bYtm0bunfvjlGjRuGqq67CLbfcghUrVuCzzz6DzWbDwoULcckll+Cmm25Cz5490bRpU8yePdv0NhITE/Hrr7/i9ttvx+WXX44XX3wR7733Hm677TYAwLBhw3DFFVfguuuuQ0pKCtauXYu77roL//73vzFixAhcffXVWLduHV566SXL+/fMM88gIiICrVq1QkpKiiqGgiAIgpAT7Nk4bSxAto+jR4/K4hGys7ORmZmJ7777Dh06dECDBg1U6xQUFCAlJQXffvst+vbtCwDIyspCixYtNGMglBQWFiIpKQkFBQUqZaK0tBSHDh1CkyZNEBMTU8U9JIIBnUOCIAgPN7+3Gn+fKsbsRzuiQ9M6fulT7xmqJGBBlA0bNpR9r1WrFgCgWbNmovJw4sQJ9OjRA1999RXat2+PpKQkDB06FCNHjkTt2rWRmJiIJ598Ep06dTKlPBAEQRDExYI9yC6MoJaydjqdyMrKQklJibjsgw8+gN1uR9++fVFWVobMzEx8+umnQZSSIAiCIEIPwYURrFLW1aZANG7cWBUpylsWExODSZMmYdKkSdUlGkEQBEGEHXe0SUdeUSnqJwbHnUuTaREEQRBEGPJUz8uCun2aTIsgCIIgCMtclAoE1RgIX+jcEQRBhAYXlQsjOjoadrsd2dnZSElJQXR0tOlqkERwYYyhvLwcp06dgt1uR3R0dLBFIgiCuKi5qBQIu92OJk2a4OTJk8jOzg62OIQPxMXFoWHDhrDbL0rjGUEQRMhwUSkQgMcK0bBhQ1RUVNDcC2FGREQEIiPNzyFCEARBBI6LToEAPPXDo6KiVDNEEgRBEARhDrIDEwRBEARhGVIgCIIgCIKwDCkQBEEQBEFYpsbFQAilsaUzgRIEQRAEYYzw7DQzUXeNUyCKiooAABkZGUGWhCAIgiDCk6KiIiQlJem2sTEzakYY4Xa7kZ2djYSEBL+k+xUWFiIjIwPHjh0znBu9JnAx7S/ta82E9rXmcjHtb7D2lTGGoqIipKenG9bbqXEWCLvdjgYNGvi938TExBp/wUq5mPaX9rVmQvtac7mY9jcY+2pkeRCgIEqCIAiCICxDCgRBEARBEJYhBcIAh8OBsWPHwuFwBFuUauFi2l/a15oJ7WvN5WLa33DY1xoXREkQBEEQROAhCwRBEARBEJYhBYIgCIIgCMuQAkEQBEEQhGVIgSAIgiAIwjKkQACYNGkSGjdujJiYGHTo0AGbNm3SbT937ly0aNECMTExaN26NX7++edqkrRqTJgwAddffz0SEhJQr1499OnTB1lZWbrrTJ8+HTabTfYXExNTTRL7zrhx41Ryt2jRQnedcD2vjRs3Vu2rzWbD8OHDue3D6Zz++uuvuPPOO5Geng6bzYYFCxbIfmeM4eWXX0ZaWhpiY2PRs2dP7N+/37Bfq/d8daG3v06nE88//zxat26N+Ph4pKenY9CgQcjOztbt05d7oTowOrdDhgxRyd2rVy/DfkPx3BrtK+/+tdlseOeddzT7DIXzetErELNnz8bIkSMxduxYbNu2DW3btkVmZiby8vK47detW4f7778fQ4cOxfbt29GnTx/06dMHu3fvrmbJrbNmzRoMHz4cGzZswLJly+B0OnHrrbeiuLhYd73ExEScPHlS/Dty5Eg1SVw1rrzySpncv//+u2bbcD6vmzdvlu3nsmXLAAD33nuv5jrhck6Li4vRtm1bTJo0ifv722+/jY8++giff/45Nm7ciPj4eGRmZqK0tFSzT6v3fHWit78lJSXYtm0bXnrpJWzbtg3z5s1DVlYW7rrrLsN+rdwL1YXRuQWAXr16yeT+9ttvdfsM1XNrtK/SfTx58iS++OIL2Gw29O3bV7ffoJ9XdpHTvn17Nnz4cPG7y+Vi6enpbMKECdz29913H+vdu7dsWYcOHdhjjz0WUDkDQV5eHgPA1qxZo9lm2rRpLCkpqfqE8hNjx45lbdu2Nd2+Jp3Xp556ijVr1oy53W7u7+F6TgGw+fPni9/dbjdLTU1l77zzjrgsPz+fORwO9u2332r2Y/WeDxbK/eWxadMmBoAdOXJEs43VeyEY8PZ18ODB7O6777bUTzicWzPn9e6772Y333yzbptQOK8XtQWivLwcW7duRc+ePcVldrsdPXv2xPr167nrrF+/XtYeADIzMzXbhzIFBQUAgNq1a+u2O3/+PBo1aoSMjAzcfffd2LNnT3WIV2X279+P9PR0NG3aFAMHDsTRo0c129aU81peXo6vv/4aDz/8sO5kcuF6TqUcOnQIOTk5svOWlJSEDh06aJ43X+75UKagoAA2mw3Jycm67azcC6HE6tWrUa9ePVxxxRV44okncObMGc22NeXc5ubm4qeffsLQoUMN2wb7vF7UCsTp06fhcrlQv3592fL69esjJyeHu05OTo6l9qGK2+3G008/jS5duuCqq67SbHfFFVfgiy++wMKFC/H111/D7Xajc+fOOH78eDVKa50OHTpg+vTpWLx4MT777DMcOnQIN954ozjdu5Kacl4XLFiA/Px8DBkyRLNNuJ5TJcK5sXLefLnnQ5XS0lI8//zzuP/++3UnW7J6L4QKvXr1wldffYUVK1bgrbfewpo1a3DbbbfB5XJx29eUc/vll18iISEB99xzj267UDivNW42TsIcw4cPx+7duw19Zp06dUKnTp3E7507d0bLli0xefJkvPrqq4EW02duu+028XObNm3QoUMHNGrUCHPmzDGl2YcrU6dOxW233Yb09HTNNuF6TgkvTqcT9913Hxhj+Oyzz3Tbhuu9MGDAAPFz69at0aZNGzRr1gyrV69Gjx49gihZYPniiy8wcOBAw8DmUDivF7UFom7duoiIiEBubq5seW5uLlJTU7nrpKamWmofiowYMQKLFi3CqlWrLE99HhUVhWuuuQYHDhwIkHSBITk5GZdffrmm3DXhvB45cgTLly/HI488Ymm9cD2nwrmxct58uedDDUF5OHLkCJYtW2Z5qmejeyFUadq0KerWraspd004t7/99huysrIs38NAcM7rRa1AREdHo127dlixYoW4zO12Y8WKFbI3NCmdOnWStQeAZcuWabYPJRhjGDFiBObPn4+VK1eiSZMmlvtwuVzYtWsX0tLSAiBh4Dh//jwOHjyoKXc4n1eBadOmoV69eujdu7el9cL1nDZp0gSpqamy81ZYWIiNGzdqnjdf7vlQQlAe9u/fj+XLl6NOnTqW+zC6F0KV48eP48yZM5pyh/u5BTwWxHbt2qFt27aW1w3KeQ1qCGcIMGvWLOZwONj06dPZn3/+yR599FGWnJzMcnJyGGOM/fOf/2SjR48W269du5ZFRkayd999l+3du5eNHTuWRUVFsV27dgVrF0zzxBNPsKSkJLZ69Wp28uRJ8a+kpERso9zf8ePHsyVLlrCDBw+yrVu3sgEDBrCYmBi2Z8+eYOyCaUaNGsVWr17NDh06xNauXct69uzJ6taty/Ly8hhjNeu8MuaJNm/YsCF7/vnnVb+F8zktKipi27dvZ9u3b2cA2Pvvv8+2b98uZh28+eabLDk5mS1cuJDt3LmT3X333axJkybswoULYh8333wz+/jjj8XvRvd8MNHb3/LycnbXXXexBg0asB07dsju4bKyMrEP5f4a3QvBQm9fi4qK2DPPPMPWr1/PDh06xJYvX86uvfZadtlll7HS0lKxj3A5t0bXMWOMFRQUsLi4OPbZZ59x+wjF83rRKxCMMfbxxx+zhg0bsujoaNa+fXu2YcMG8beuXbuywYMHy9rPmTOHXX755Sw6OppdeeWV7KeffqpmiX0DAPdv2rRpYhvl/j799NPisalfvz67/fbb2bZt26pfeIv079+fpaWlsejoaHbppZey/v37swMHDoi/16TzyhhjS5YsYQBYVlaW6rdwPqerVq3iXrPC/rjdbvbSSy+x+vXrM4fDwXr06KE6Bo0aNWJjx46VLdO754OJ3v4eOnRI8x5etWqV2Idyf43uhWCht68lJSXs1ltvZSkpKSwqKoo1atSIDRs2TKUIhMu5NbqOGWNs8uTJLDY2luXn53P7CMXzStN5EwRBEARhmYs6BoIgCIIgCN8gBYIgCIIgCMuQAkEQBEEQhGVIgSAIgiAIwjKkQBAEQRAEYRlSIAiCIAiCsAwpEARBEARBWIYUCIIgCIIgLEMKBEEQBEEQliEFgiAIgiAIy5ACQRAEQRCEZUiBIAiCIAjCMqRAEEQIMX36dNhsNhw+fDgo2x8yZAgaN24clG3XBOj4ERcTpEAQRAD49NNPYbPZ0KFDh2CLoiI7Oxvjxo3Djh07gi2KJd544w0sWLAg2GIQBFEJKRAEEQBmzpyJxo0bY9OmTThw4ECwxZGRnZ2N8ePHcxWI//73v8jKyqp+oUwQDgpEKB8/gvA3pEAQhJ85dOgQ1q1bh/fffx8pKSmYOXNmsEUyTVRUFBwOR7DFCFvo+BEXE6RAEISfmTlzJi655BL07t0b/fr101Qg9uzZg5tvvhmxsbFo0KABXnvtNbjdblW7hQsXonfv3khPT4fD4UCzZs3w6quvwuVyydp169YNV111FbZu3YrOnTsjNjYWTZo0weeffy62Wb16Na6//noAwEMPPQSbzQabzYbp06cDkPvwnU4nateujYceekglU2FhIWJiYvDMM8+Iy8rKyjB27Fg0b94cDocDGRkZeO6551BWVmZ4zPbv34++ffsiNTUVMTExaNCgAQYMGICCggIAgM1mQ3FxMb788ktR5iFDhojrnzhxAg8//DDq168Ph8OBK6+8El988YVsG6tXr4bNZsPs2bPxn//8B6mpqYiPj8ddd92FY8eOGcpYVFSEp59+Go0bN4bD4UC9evVwyy23YNu2bWIbZQxEt27dRHmVf8IxB4D8/Hw8/fTTyMjIgMPhQPPmzfHWW29xrweCCBUigy0AQdQ0Zs6ciXvuuQfR0dG4//778dlnn2Hz5s3igxsAcnJy0L17d1RUVGD06NGIj4/HlClTEBsbq+pv+vTpqFWrFkaOHIlatWph5cqVePnll1FYWIh33nlH1vbcuXO4/fbbcd999+H+++/HnDlz8MQTTyA6OhoPP/wwWrZsiVdeeQUvv/wyHn30Udx4440AgM6dO6u2GxUVhX/84x+YN28eJk+ejOjoaPG3BQsWoKysDAMGDAAAuN1u3HXXXfj999/x6KOPomXLlti1axc++OAD/PXXX7quh/LycmRmZqKsrAxPPvkkUlNTceLECSxatAj5+flISkrCjBkz8Mgjj6B9+/Z49NFHAQDNmjUDAOTm5qJjx46w2WwYMWIEUlJS8Msvv2Do0KEoLCzE008/Ldve66+/DpvNhueffx55eXmYOHEievbsiR07dnCPv8Djjz+O7777DiNGjECrVq1w5swZ/P7779i7dy+uvfZa7jovvPACHnnkEdmyr7/+GkuWLEG9evUAACUlJejatStOnDiBxx57DA0bNsS6deswZswYnDx5EhMnTtSUiSCCCiMIwm9s2bKFAWDLli1jjDHmdrtZgwYN2FNPPSVr9/TTTzMAbOPGjeKyvLw8lpSUxACwQ4cOictLSkpU23nsscdYXFwcKy0tFZd17dqVAWDvvfeeuKysrIxdffXVrF69eqy8vJwxxtjmzZsZADZt2jRVv4MHD2aNGjUSvy9ZsoQBYD/++KOs3e23386aNm0qfp8xYwaz2+3st99+k7X7/PPPGQC2du1a1bYEtm/fzgCwuXPnarZhjLH4+Hg2ePBg1fKhQ4eytLQ0dvr0adnyAQMGsKSkJPH4rVq1igFgl156KSssLBTbzZkzhwFgH374oe72k5KS2PDhw3XbKI+fkrVr17KoqCj28MMPi8teffVVFh8fz/766y9Z29GjR7OIiAh29OhR3W0SRLAgFwZB+JGZM2eifv366N69OwCP6b1///6YNWuWzOXw888/o2PHjmjfvr24LCUlBQMHDlT1KX0rLioqwunTp3HjjTeipKQE+/btk7WNjIzEY489Jn6Pjo7GY489hry8PGzdutXy/tx8882oW7cuZs+eLS47d+4cli1bhv79+4vL5s6di5YtW6JFixY4ffq0+HfzzTcDAFatWqW5jaSkJADAkiVLUFJSYkk+xhi+//573HnnnWCMybadmZmJgoICmYsBAAYNGoSEhATxe79+/ZCWloaff/5Zd1vJycnYuHEjsrOzLckokJOTg379+uHqq6/Gp59+Ki6fO3cubrzxRlxyySUy+Xv27AmXy4Vff/3Vp+0RRKAhBYIg/ITL5cKsWbPQvXt3HDp0CAcOHMCBAwfQoUMH5ObmYsWKFWLbI0eO4LLLLlP1ccUVV6iW7dmzB//4xz+QlJSExMREpKSk4MEHHwQAMUZAID09HfHx8bJll19+OQD4VFsiMjISffv2xcKFC8VYhnnz5sHpdMoUiP3792PPnj1ISUmR/QnbzsvL09xGkyZNMHLkSPzvf/9D3bp1kZmZiUmTJqn2jcepU6eQn5+PKVOmqLYtxG4ot6087jabDc2bNzc8Pm+//TZ2796NjIwMtG/fHuPGjcPff/9tKCMAVFRU4L777oPL5cK8efNkgZb79+/H4sWLVfL37NmTKz9BhAoUA0EQfmLlypU4efIkZs2ahVmzZql+nzlzJm699VZLfebn56Nr165ITEzEK6+8gmbNmiEmJgbbtm3D888/Xy1BdgMGDMDkyZPxyy+/oE+fPpgzZw5atGiBtm3bim3cbjdat26N999/n9tHRkaG7jbee+89DBkyBAsXLsTSpUvxr3/9CxMmTMCGDRvQoEEDzfWE/X/wwQcxePBgbps2bdoY7aIp7rvvPtx4442YP38+li5dinfeeQdvvfUW5s2bh9tuu0133WeffRbr16/H8uXLVfvjdrtxyy234LnnnuOuKyhhBBFqkAJBEH5i5syZqFevHiZNmqT6bd68eZg/fz4+//xzxMbGolGjRti/f7+qnbKGwOrVq3HmzBnMmzcPN910k7j80KFDXBmys7NRXFwss0L89ddfACBmB9hsNkv7ddNNNyEtLQ2zZ8/GDTfcgJUrV+KFF16QtWnWrBn++OMP9OjRw3L/Aq1bt0br1q3x4osvYt26dejSpQs+//xzvPbaa5pyp6SkICEhAS6XS3xjN0J53BljOHDggClFIy0tDf/3f/+H//u//0NeXh6uvfZavP7667oKxKxZszBx4kRMnDgRXbt2Vf3erFkznD9/3rT8BBEqkAuDIPzAhQsXMG/ePNxxxx3o16+f6m/EiBEoKirCDz/8AAC4/fbbsWHDBmzatEns49SpU6qUz4iICACeh5xAeXm5zIcupaKiApMnT5a1nTx5MlJSUtCuXTsAEJWL/Px8U/tmt9vRr18//Pjjj5gxYwYqKipk7gvA83Z+4sQJ/Pe//+Uem+LiYs3+CwsLUVFRIVvWunVr2O12WQpofHy8SuaIiAj07dsX33//PXbv3q3q+9SpU6plX331FYqKisTv3333HU6ePKmrBLhcLpVLpV69ekhPT9dNU929ezceeeQRPPjgg3jqqae4be677z6sX78eS5YsUf2Wn5+vOjYEESqQBYIg/MAPP/yAoqIi3HXXXdzfO3bsKBaV6t+/P5577jnMmDEDvXr1wlNPPSWmcTZq1Ag7d+4U1+vcuTMuueQSDB48GP/6179gs9kwY8YMmUIhJT09HW+99RYOHz6Myy+/HLNnz8aOHTswZcoUREVFAfC88SYnJ+Pzzz9HQkIC4uPj0aFDBzRp0kRz//r374+PP/4YY8eORevWrdGyZUvZ7//85z8xZ84cPP7441i1ahW6dOkCl8uFffv2Yc6cOViyZAmuu+46bt8rV67EiBEjcO+99+Lyyy9HRUUFZsyYISoHAu3atcPy5cvx/vvvIz09HU2aNEGHDh3w5ptvYtWqVejQoQOGDRuGVq1a4ezZs9i2bRuWL1+Os2fPyrZXu3Zt3HDDDXjooYeQm5uLiRMnonnz5hg2bJjm/hcVFaFBgwbo168f2rZti1q1amH58uXYvHkz3nvvPc31hDiMm266CV9//bXst86dO6Np06Z49tln8cMPP+COO+7AkCFD0K5dOxQXF2PXrl347rvvcPjwYdStW1dzGwQRNIKaA0IQNYQ777yTxcTEsOLiYs02Q4YMYVFRUWK64c6dO1nXrl1ZTEwMu/TSS9mrr77Kpk6dqkrjXLt2LevYsSOLjY1l6enp7LnnnhPTK1etWiW269q1K7vyyivZli1bWKdOnVhMTAxr1KgR++STT1SyLFy4kLVq1YpFRkbKUjq10hDdbjfLyMhgANhrr73G3b/y8nL21ltvsSuvvJI5HA52ySWXsHbt2rHx48ezgoICzePy999/s4cffpg1a9aMxcTEsNq1a7Pu3buz5cuXy9rt27eP3XTTTSw2NpYBkKV05ubmsuHDh7OMjAwWFRXFUlNTWY8ePdiUKVPENkIa57fffsvGjBnD6tWrx2JjY1nv3r3ZkSNHNOVjzJMO++yzz7K2bduyhIQEFh8fz9q2bcs+/fRTWTvl8WvUqBEDwP2TptEWFRWxMWPGsObNm7Po6GhWt25d1rlzZ/buu++K6bcEEWrYGNN4lSEIIqzo1q0bTp8+zTXlE554ku7du2Pu3Lno169fsMUhiLCHYiAIgiAIgrAMKRAEQRAEQViGFAiCIAiCICxDMRAEQRAEQViGLBAEQRAEQViGFAiCIAiCICxT4wpJud1uZGdnIyEhweeSugRBEARxMcIYQ1FREdLT02G369sYapwCkZ2dbThxD0EQBEEQ2hw7dkx3IjugBioQCQkJADw7n5iYGGRpCIIgCCJ8KCwsREZGhvgs1aPGKRCC2yIxMZEUCIIgCILwATMhABRESRAEQRCEZUiBIAiCIAjCMqRAEARBEARhmRoXA0EQNQmXywWn0xlsMQiTREVFISIiIthiEES1QAoEQYQgjDHk5OQgPz8/2KIQFklOTkZqairVoSFqPKRAEEQIIigP9erVQ1xcHD2MwgDGGEpKSpCXlwcASEtLC7JEBBFYSIEgiBDD5XKJykOdOnWCLQ5hgdjYWABAXl4e6tWrR+4MokZDQZQEEWIIMQ9xcXFBloTwBeG8UewKUdMhBSIMoBnXL07IbRGe0HkjLhZIgQhx8opK0f6NFZjwy95gi0IQBEEQIqRAhDifrT6IU0VlmLzm72CLQhABYdy4cbj66qurZVtDhgxBnz59qmVbNZVfdp3ETztPBqTv82UVeGzGFvz4R3ZA+if8CykQIY7bTe6LQFLqdOH0+bJgi1HjWL9+PSIiItC7d++gbP/w4cOw2WzYsWOHbPmHH36I6dOnB0WmmsCFcheemLkNw7/ZhsJS/8d4fLrqAJbsycWT3273e9+E/yEFIsQh/SGwdH1nFa57bTlO5F8Itig1iqlTp+LJJ5/Er7/+iuzs0HmbTEpKQnJycrDFCFvKKlzi59Jyl05L3yBlPrwgBSLEcVMAZUDJLfQMWL/9dSrIktQczp8/j9mzZ+OJJ55A7969VW/8b775JurXr4+EhAQMHToUpaWlst83b96MW265BXXr1kVSUhK6du2Kbdu2ydrYbDZ89tlnuO222xAbG4umTZviu+++E39v0qQJAOCaa66BzWZDt27dAMhdGFOmTEF6ejrcbres77vvvhsPP/yw+H3hwoW49tprERMTg6ZNm2L8+PGoqKioyiGqGQQgVpSGu/CCFIgQhywQ1UOoH2bGGErKK6r9z5cMoDlz5qBFixa44oor8OCDD+KLL74Q+5kzZw7GjRuHN954A1u2bEFaWho+/fRT2fpFRUUYPHgwfv/9d2zYsAGXXXYZbr/9dhQVFcnavfTSS+jbty/++OMPDBw4EAMGDMDevZ5g402bNgEAli9fjpMnT2LevHkqOe+9916cOXMGq1atEpedPXsWixcvxsCBAwEAv/32GwYNGoSnnnoKf/75JyZPnozp06fj9ddft3xcagKBfsDTeBdeUCGpEIdiIKqHUH/zueB0odXLS6p9u3++kom4aGvDxNSpU/Hggw8CAHr16oWCggKsWbMG3bp1w8SJEzF06FAMHToUAPDaa69h+fLlMivEzTffLOtvypQpSE5Oxpo1a3DHHXeIy++991488sgjAIBXX30Vy5Ytw8cff4xPP/0UKSkpAIA6deogNTWVK+cll1yC2267Dd988w169OgBAPjuu+9Qt25ddO/eHQAwfvx4jB49GoMHDwYANG3aFK+++iqee+45jB071tJxqQ6+33ocaw+cxlv92iAqIrDvh7YAmCAoZT28IAtEiOOiG6paYCFvgwgPsrKysGnTJtx///0AgMjISPTv3x9Tp04FAOzduxcdOnSQrdOpUyfZ99zcXAwbNgyXXXYZkpKSkJiYiPPnz+Po0aO663Xq1Em0QJhl4MCB+P7771FW5nFlzZw5EwMGDIDd7hka//jjD7zyyiuoVauW+Dds2DCcPHkSJSUllrZVHYya+wfmbT+B77ceD0j/gb5L6C4ML8gCEeJQDAQBALFREfjzlcygbNcKU6dORUVFBdLT08VljDE4HA588sknpvoYPHgwzpw5gw8//BCNGjWCw+FAp06dUF5ebkkWM9x5551gjOGnn37C9ddfj99++w0ffPCB+Pv58+cxfvx43HPPPap1Y2Ji/C6PvzhXEvgqmIGol0XjXXhBCkSIQ/dT9RDqx9lms1l2JVQ3FRUV+Oqrr/Dee+/h1ltvlf3Wp08ffPvtt2jZsiU2btyIQYMGib9t2LBB1nbt2rX49NNPcfvttwMAjh07htOnT6u2t2HDBlU/11xzDQAgOjoagGdeET1iYmJwzz33YObMmThw4ACuuOIKXHvtteLv1157LbKystC8eXMzhyBkCFQxzEC7GEL9PiTkhPaIFGIwxqq9TC1p5NUDHeWqs2jRIpw7dw5Dhw5FUlKS7Le+ffti6tSpeOaZZzBkyBBcd9116NKlC2bOnIk9e/agadOmYtvLLrsMM2bMwHXXXYfCwkI8++yz4iRVUubOnYvrrrsON9xwA2bOnIlNmzaJrpJ69eohNjYWixcvRoMGDRATE6OSSWDgwIG44447sGfPHjF2Q+Dll1/GHXfcgYYNG6Jfv36w2+34448/sHv3brz22mtVPWRhTSBGQhrvwguKgTBJqdOFnu+vwb9n76jW7booiLJ6oIGrykydOhU9e/bkPqj79u2LLVu2oGXLlnjppZfw3HPPoV27djhy5AieeOIJVT/nzp3Dtddei3/+85/417/+hXr16qn6HD9+PGbNmoU2bdrgq6++wrfffotWrVoB8MRefPTRR5g8eTLS09Nx9913a8p98803o3bt2sjKysIDDzwg+y0zMxOLFi3C0qVLcf3116Njx4744IMP0KhRI18OUbURqNccioEgpJAFwiSrs07h4KliHDxVjA/6X11t29XSyJ0ud8CjrC8maOCqOj/++KPmb+3btxfN323atMF//vMf2e9vvfWW+Pmaa67B5s2bZb/369dP1Wd6ejqWLl2quc1HHnlEzNIQ4FWhtNvtusWuMjMzkZlZ/fEnoYh0OAqENZayMMILegKZJjgXtqLGDQDg3SVZaPHSYuw9WVj9AtVQaNwiCGOk2UoBcWFwxjsidCEFwiTBesDwLBCfrDoAl5vh7cX7giBRzYTefAgi+FA6dXhBLowQRy8EoroDOmsyNGyFF6Tw6ROwoSHAh51Oa3hBFogQh6KSq4dgH+a8olI8+tUW/EpzchB+IBBVIpUE4pahmPHwghQIkwTruqYsjOoh2Ed53A97sPTPXAz6YlOQJSEIbQJ/nwT7TiSsQAqESUIpBoKoeWTnl6qWKWeJJMKDi+W8BcKNRO9L4QXFQJgkWME9pD9UD6HkU4+OjhZTC1NSUhAdHU3xLmEAYwzl5eU4deoU7Ha7WA2zJhH42ThD5z4kjCEFIsTRu6HokVIzsdvtaNKkCU6ePKlbn4AITeLi4tCwYUNxQq5gwNM3/VFJV/oiFYhHPekP4QUpECEOxUBUD6E2cEVHR6Nhw4aoqKgwnM+BCB0iIiIQGRkZchajH/7Ixrgf9mDyP9vh+sa1fe6HLBCElJBXIN58802MGTMGTz31FCZOnBg0OSgGomYTivnnNpsNUVFRiIqKCrYoRJjzr2+3AwCGTt+MneN8r6opvUtoaCJCOohy8+bNmDx5Mtq0aRNsUYL2eCEDRPVAgyERyrj9NBBU9TqvSbNxni+rwJI9OSh1koXPV0JWgTh//jwGDhyI//73v7jkkkuCLU7QguzIAlE9BPsoB3v7ROjyx7F8XP3KUny94UiV+6rqdSYdjgJhtavO8W74zG14bMZWvLxwd7Vts6YRsgrE8OHD0bt3b/Ts2VO3XVlZGQoLC2V/NQn9SpTVJ0dNJ+h6WtAFIEKVp2ZtR2FpBV5cYP5BF2oxGGapTgViTWXRtjlbjlfbNmsaIalAzJo1C9u2bcOECRMM206YMAFJSUniX0ZGRjVIWH1Uh+Xj2NkSv5lIq8Ifx/Lx3Hd/IK9IXRMh0IRiDARB+ErAKlnLgiAC3D8BACgqdWJVVh6crtCrLxJyCsSxY8fw1FNPYebMmYiJiTFsP2bMGBQUFIh/x44dqwYpq49AZ2F8s/Eobnx7FZ77fmdAt2OGuyetxZwtxzH6+13BFoWoJopKnVj2Zy7KKsgPrUUoWBOEF5lAK9qkQKgZMm0zHpq2GR+t2B9sUVSEnAKxdetW5OXl4dprr0VkZCQiIyOxZs0afPTRR4iMjFSltDkcDiQmJsr+AkGwLuxAKxATl/8FAPhua+iY8Q6eOl/t26SBKzDsOl6Af07diD3ZBdzfH/1qK4Z9tQVv/hJ+M8tWV1xUsNWHC+Uu3PLBr/jP/F2KGAj/U5MsgXO3HMML83dV2bq79cg5AMCcLaH3chxyCkSPHj2wa9cu7NixQ/y77rrrMHDgQOzYsQMRERFBkaumVqK0h8DbjZJQlInwjXs+W4vf9p/GgCkbuL+v//sMAGD25tAbHPUYu3A3bnpnFQpLnYHfmB9vB8YYGGP48Y9sHD5drNlu94kCPPH1Vvx96jx+3nUSB/LO45uNRwM+CtYkRf7Z73Zi5sajWPpnjl/6CwEvs4qQqwORkJCAq666SrYsPj4ederUUS2/GAh0UFEoPquDIVKwS1mH2tjAGIObARH2qp0Np8uzZ0WlFf4Qy++43Ayrs/JwbcNLcEm8+dLTX673ZETM23ocQ7o0CZR4PqN3X/+06ySerKwLcfjN3tw2d3z8OwBgX04RHrupqbhcep8E4papiVln+SX+UTKDPUbxCDkLBCFH/4aq+qM2JN/2gyBSCN6bQeW+yetx09urUF5RPYFbwTr+09YewtAvt+DuSWt9Wr+qCpYZfNmC1jo2mw1bDp8z3c+h08WyN99An6ZQfMuuKv7apVCsShxyFggeq1evDrYIQaxEGZztBpOQVGouMjZXPmT2nixE24zk4AoTQBbtPAkAOHq2xKf1I4I434UeC3Zkcy0jvrzFFpd5rUeBrgMRim/ZoUIoPgtC8+oPQWpqEGUojn/V8FKnIpj35ofL92PncX6Q4cVCuAbPRUZUgwXCB4V6x7F87D7Bv6asdicEWgPVUIkyoL0HB38dslB074Tg4yM0CV4p6wArECH4th8MmYJ1cx48dR4fSAboUKO6jkqw3q6qutlIHW336JkSFPjJ/+0LvlpVlBSXezPfAj0XRgg+I6uMv5RjIZvjnSX70GvirzLLULAgBSLECfQNFXrqQ3AI1sAVqsGFhDm0YiBO5F/ATe+sQttXllZ5G2bv0UOKrIpAWPKU90lZhQuLd59EwQX/KEqh+JYdKghK9qRVB7EvpygkMpdIgfCB6vTT6d1Q/nhRD0ULRDAK59CwVTVcbla1fPdgnQAf7mWpWzFSwwe4/aj5QEUjzN4Or/y4R7Eef0VbFV4bpG/TDMCbv+zD419vw9Dpm33uU9Z/DbwRA+XCCIXKlKRAmCTQ6UtaBDzyNvT0h6DEQASLmhA0VuFy4+b3VuPOT34Pu/3xRVrpwK0VA6GlWASSCsVYwXs58OdkWgDwXeU8EluO+EdhIguENqF4aMIiCyPUUJ7HCT/vRXSkHaNuvcLv27oYYyCCIlI13527TxRg7pZj6HZFvWrdbiA4crYER854/O1uBvgSVxhOQZSfrzkoftaKgYiqhuBKJcqxQjuV0/dtyLIwGFMpLVUlFB+SVcVvaZyKgxMKh4oUCJNonay8olJM/vVvAMAT3ZohLtq/hzTQBohQfNuvionVV6r7ZhQK9azYl1fNW7aGGYuC8qESkmYtDaSyF5U6kRATZbjOxOXeOQm0YiAiI/xngTB7P7gVFm0tRaEqD2mloudvC2k4KZLVjVJBDAVli1wYZlENkh4qXN7PTpf/z2jAK1EGcbBfuOMEHv1qiyqaOChpnEG6GY+fu1Cl9UPDZcA4nyz2EAK7cePbqyyvoyV2lOQirupD1qzFQDlWBMK6KFcWAadSa6kieoeKMYYZ6w+bji/58Y9svL80K/j3iJ+2H+zd4EEKhA9Iz6P0DSQQU2Lr9emP4SGYHoynZu3A0j9z8cXvh2TLgxNEqX/uluzJwSNfbsbZ4vJqksiY/blFuPbVZfhvpQUsFPB1kAtaDKVkyz6VHNYQXGqBqK5gN9XLhsZt5M/by98PNb2H/ZI9uXhp4R7849N1pvp68tvt+GjlAaw/eMZf4vlEoK7tULDWkAJhEln0seS8SW/GQFgLAu/CCL65WTkhUQiIpOKxGVuxfG8e3l2aFWxRRF5auBvnSpx4/ee9ptcpLHUaKkFVeWO72ILgtPZXGlxZVQXCrEKtHCu4QZRVPD2BPr16/R/IK/Kpz1Pny3yUJrQJhVuNFAgfkCoT0ps0EA/7i6ESZZTCXxxlUqijZ0rwn/m7VPnvvmD2Zjx7PnQsEL7QZtxSXPvqMt0iNFYHJn8MZMEyMwfqgSq9hgPh2uShHCsCUgdC40XKX+inrfu2Q2bkDOT1FwoP+kARAo+P8EDrIoiQXNSBuAhrcgyEgFKBcESZuyyHTNuEbzYexf0aU0VbwayeFhGE6HotqnLu/FWlEAh8dcJAUlV5te5P6bOuws9xAmZl4V0fNlvVXJ+BtogG6/KprkqoQY/H8DOkQJhEa5CUuzACsN1Al4EIgedhdKRCgYiMMLXe35WWh5zC0irLYNafGBEKB8wP6CmmVbnkfPXLBnpYdbsZLkhKMvtru2bu+SoHUZpsZ8YCUdXxRLqNwEym5fcuTckZSEuvoDSMmvMHbvngV5Q61ddhVfoNJqRA+JHAxEDomfSq3n8wAhaVKHPmY0xaIPyKyVOnN/dBdeOvfH7ddhb7DcUZAwGg/5T1aPnyYpz2uz+cv8PS41tRRReG+SwM5Xr+v1YDbRENxHhnxgBUHbE73287jgN557HST6nbIaA/kAJhFmX6Eo9AXITK4iFVgTGGsgq59hus56FUe1a5MExaIIKBPYQUiEBh9c1GXQfCl236tJpphOnJl/2Z69d+tRQm6Vuvv4stacqi2I7WA7cqeoXMAlGF3VIGTpvp01eXnRkxA2qBUHyvruuhOiAFwge0AokC4er058D6yJdb0HrsUlkUfrCyMKQ3UVSEXTb4BcMCYfYw1xQLROBcGOGFVYXnZIG8boeZ1V1VHBjMnmfly0YgrlR/vCS9/tOfaDNuKVZVUxE1M+c4kBYIZdf+SvcPhXuNFAgfqE4LhD9ZsS8P5S43ftqZLS4L1uOwrMI7qEZH2FEqsYzERFW/BcLsg8TXh/a6A6cxas4fQZve+djZEszceET87s9LVaZQB39+H12qut+dJqyUfde652UujCrHQJitRKm0QATAhSE5v77u1X9/OwQAeIOTfuyrC0Pv/jVzzqspzhWA/6wdofC4oVLWJjETiBPqCoSA9PoNVghEmSSQKDLChlKn9w4OigXCbEyAj6f4gf9tBODZt9f/0dq3ThRYMen2eG8NyiX1CPR2w5+ljkMNf8tnpreqxkCYRTVXAudEMrAqKRb+dKnyxNBVIDSW/7TzJF5csAuTBl6Lzs3qqn43FUQZSAtENW6ruiELhElkfl7wP4eia2vn8XxMXP6XLPZBepMGK4hSaoEAgAsShSLYqaXDvtqCAVPWcwdgq0piqdOF+z5fL36vaulqXylXFDMyux/mcuittQ8lql4HQsMCIflcbTEQ1TDZklsWA+F/y4ovh2r4N9twrsSJwV9s4v5ups/AujDkffvLAjFjw2EUXAiORVOALBAmkSkNGhdbdVsgzDxo7/pkLQDtglfBcumXSxQIN4MsxS4Yb7FCXEiFyy0G2h0+U4ImdeNxWFKoyuopnrvlGDYdPus3Of2Frsm3Csc/zPSHKl9roRQDodxMIIYjrYef1qRievD2SzeI0mATWgW7zLkwqu/K9ZcCcfp8OZ6d+wemDLrOL/35AlkgfEDr9IeyC2PvyULxs/ThEawgSml5X8aYLDc6GIdx3vYTOHpGXlzJzRjW/HUK3d5dLS6zKlog3xAClcZZlUqUoXwPBALtGAhJFkZV0zh9lEVLOarKHS81v0t790WB4KPnwvA1CyO4Lgwl/rxHlvo5q8gqpED4gFbaWnUG4lhFqvXKXRjBkEYOY5ArEEGSY+mfOaptz958VPbd6r2vdNUEimMWK0v68xhLr6dQ1x+U8gWqlLVfXRgmb1KVLAE4F1pv6v7KTgpWMb7AFpKqvm1VN6RAmMXPqUCrs/Lw9Kzt1ebDksomd2EEX4NwMyYLogzmQ0jpz49QzMth1e9bXQqE1amorR5jl5vhvsnr8dx3f+j2FepBlP7GzD0frLx/3larXIlSQ1k0Y4FwutyyeWt48Vd695evQ5WZXQ7kmKO8J/ypQAR7+CYFwiRM84sXKwrEkGmbsWBHNj5Y9leV5DKL9JoNhSBKZfDpBaf1GAh/i65WppjqzcrqrV/mp7K1/saqGXXrkXPYdOgs5mw5rvpNdr7CTH+oqrha60sPb5VjIDjLzhWX48lvt+PXv05JZFG4MKoxBsKMBWLItE3oLnEH8tYIjAXChAvDTxsudboMt+dXBcJvPfkGKRAm0XrLUj4IreLPqHzGmKaJUauCXLAvQECwQFh/0PpbdptNXSRMqVRYtUBILSv+pirKn24hKc5PehNCyZVTn0UKClXNJDCzvi8xEOUVbrzy45/4eddJ7u8fLP8LP/6RjUGSzAOVe4aj3pRVuLE665RquVnk1433s9JSx2PtgTOy75bTOAM4G6c/YiCOnS1Bi5cWY8S323W3799U2OCO4KRA+IB/C0n572Ia9tUW9PrwV1mGg4DcTy298YNkgVDEkUgtEGYPifTmWXvgdJVlsinlgvrNyrIFoiL4FgjuQ07HLyt/8Fhz3YW6C8Pf0u04VmC4JV9cGB+v3I8v1h7C/83cxn3QnilWTyuvCoHQ2GxWbpFleQS09EhfYiC4zz4/nSDpdPVmlDx/TEz19QZPobafdp7U7c+fGR/BfgEkBcIHtE6/LxeG3nVrtb/le/PwV+55bD1yTt2XZEPSkgChEQPhWxClVPK3Fu+rshzKeS4YU0/fbfUUBzIGwuyZ46W3SZfMWH8YV45djC166aY6+81kyqlJoQCcKvL3xFbWqepQ/u2mo4ZtfFEg5m07IX7mneda0d4M/Ee+3Ix9OYUcC4T/qUoMhBJeVoWezFa2MPaHPeJnczOmWujcBFp1g/y9rWAP36RAmIRpvMHLfZ0+KBA6v/kafKUsGgRoZ2GEwtwO6jROc/stVX78sRc2m001CKssEBbfVHjWoOqmlGMFkV4DLy3cg1KnG0/P3gGArwToXYq+pnH2n7zeuFGYUtUYCCOXXq0YrwKxfG+erFiZQEBmB5ZZqrxERvjHAuGv2Th//MNbrt/MUfB3ZoReb351YQTZBkEKhEnMnHJfrkFdU5ePFxpv8JHVsJe6MBQ3/ubDZ9HxjRVYvDvHp22bRWrqdjP5W7LZ3ZYNKH5QxdUhlEwdA1HlrVQ/vOuBqyToXMB6rgnpalYu2b8lEflBI0AnVNqtVoEjPaTWMJ6fu5ZDXgOwsLQCqp0JwL5JH37S2A6tFxGrCrc/lJ79uUUyy18wJtPSs8pVNahWBlkgwgM9k5SALxehrqZqoJFoPTN5ZnOtNE7pjc8Yw/1TNiCnsBSPf71Vd9v+xM2Ywgevz3dbj+PGt1fK9tMf95HdZlMFUapMsxZPcSCDnMx2XcYJ5MznpA8LDwfeLuq62kKg+ANjDP83cyteWrDbqGH1CCThue924s/sQuOGEiIMrGsJMeoiwmaCKKuKVGmQFoOL5ARRvrV4H7q8uRJnzvNdVVazMMzeSX0/Wyf7buaU+12BkH2W923FhWGk/ATbfkwKhEm0iuXI36T9GwNhZnIcHrzUQRfjyyl9QFa4WVBy1hnzlJCWftfjmbl/4NhZefaKP57TNpsyuFP9ZmX1HAfbRwnwLRD/UkSKA96BzSimQXkdhkIhqT9PFuLnXTmYseGIpbfeQImrFGHMvJ2W1jeKKYh3cBQIAxn8gVRpkLpKefJ+tvogsgtKMfX3Q/zOuLWstbdtVhn3WGOkXRofCH+4MGRKg58UbmkWF++SCPb4QgqED2hdkL5chHoXk6/RunwLBP+z9E1HKr90QDCT22wVpd+8QjMLwBx+iYGA+s1BOTBaPQyBfKCa3WezqaS6qZ3Sz6onlbk+Asl5yUOjuNyFf07diI9W7Ddcz9/Xtb8wyorkPUyU+yJ8VZZorwpSpcFZoa9ACGgNY/wkDP+fj+BYIPRcGOa3VVzuva6jI9UXBcVAhAka6c/abfyA0hpgtn+uAqExi540+En6dlEnPhoAkFNQys1t9iduVvU3AH+4Cuw2m2oQVqdxVs2nG4w3Bl4QJQ/hHPD2UJ6qqfwNmr9VF6WSa37O5mP4bf9pvG+iSJv6rd0/e6Dsx+r1KXUJWJ10SmxT+X/LEf9N5iYNCpZPD6+nfPJ/4wdRam87kJUo/Z2Fob8tCwqEJB2VtxpZIMIETQ1VstjfkbxKC4TZ3nm1B7RiDKRvDtI2SbFRAIBvKtPUftrpLWhz6HQxznHy0LXYdOis7lwNbsZ8CqKU4p8sDPl3xtSpnUv25OKbjcapewI+WaVMrmP2oWS2SJfZ7eq7MIKjQkj3MTtfuzibkXSB8uBZTXaStjeb7qilDPmz1ov0JcP0PWvhmOqWsjbfjaJP4zZVsUDsPlGAM+fLTKczW1MgvNd1eYUbu0/I644E20NKCoRJTBggfErPCUQMBM9kLQuilLoqbFILhHe5sLikTO5PPHa2BN3fXY1rXl1mSpY/swtx3+T1qrka5LEGzL+RyT5is9lUpnpedPl/5u8y3acvA9O87SeMG1nArAtDDKI0eDopxz8960R1IX1TO2tBuVXCu8eKSp0Y9tUWjF1oEKAp7UfxXfkQLyp16hYZk8cmqc8f7xwpU4aFJlasH0ZyOTWCKHXTfDWW82rQBOL6MXMP+uou3n2iAHd8/Dvavbbc9DpWnhNK6+EdH/8u+x7sSpTqSByCi5lAMV/evvQubqWmqmypnYWhn/evZXKWDlRCm+JyeV/bj+Vrictl53Hj9owBTg0Xi1n8EkQJ9bk1U6JXD9U5NLFrmw+dRb92DQzbmY+BsObC4I3i8vLt2tdlsCwQRZIYCF6VRikTl/+F+okxuL99Q1PFlz5fcxDLLE6brOxXOtAXlTrRetxSpCQ4sPmFntz17RqxSV451ctKFPeqIEOEyZtDkKt2pfuSh1RJcbrMpUpq/caNgdC7fAL4sPS1NsOGv73luZUB2N7PCoudBWXFMAsjyCYIUiBMopV5IT29vvjRdKN1fXwp5xUv0iokJd1+hcwk6fl8oVwR0WzxRjPT2s0Al4t/TM3ij2Aiu12+cQb1ZFpW8cWFoazNUVVMuzB0HwL8z57vVXM/+YNCSVqq1ALhdjOZG+qv3CJ8vcHjgrq/fUNVP7xjkF9S9RlzpZfRrkoztF4lTmV2lG8ILgydFoyJyo0gl54FRxr3IK+1oLcNM7Ia4+tdYeaB7esh5gU2AvpBlFaskkZNyYURhmidVH9H8qpdGObWM8rCkA34GjX7hU9KC4RVtI6JMv3V6TY3GGlhRRP/4vdDeGnBbpUyxKsDoYyBsIovQZRm3xjNUmqyGqY3iFJ9HegVi5IVKfNFQD9QWMpXIJRVWQsu6Kf48a69KL0nsAbKfqUKgTRAUkhf3pdTiO+2HhevyUiN2CQ9OZUIq+mVq5d2bUYJd2pYIHx5gze6zPeeNFc7w6gfUy8xPmoQ0RrXhn4MhPn+jaQKtguDFAiTaJVwlbXxJQZC5xJRxgWYzQDgFQ6SzYWhIaeL8yQoUVkgTIlgqb0nBkLbRG4GK/fRK4v+xIwNR7hzhsjkQtVLfftkgfBzeXGtKcWVspkVVXl+5O6xUHBheN/seWXdpaitKeo2/ij3Ln2Iy1KkKx/IY+btwjNz/8DuE5UPTZu+BcLMURZdGDryS5VoM/eQVhqnXtyElqxGCsttH/6GHRKXqVQ+qdxG8/mYuSR9DYCXKpcyV57GZ8+2zGsQhhYIysKQM2HCBFx//fVISEhAvXr10KdPH2RlZQVVpvIKN84rHqQCynoGVtG7bpVjn9nueWl7chcGv08nx40gjQL2rBsYF4bM+uGLBcIHY16RIkBUORcGY+o6EFbxZWAy+8AyO3houTCcGg9X3vFnjH/9AMoYCHMy+RupBUIaNKp056nKlav2Rb0DkToWiPSkGP4Pim6kD7lIRY0VwOsmKah0xUiPN/caMnGghX3Rs6IxeOu8mLmc5DEQXhku6FgqNUU1scFV+/IkzaVVc71tjG4XM2OWrzEQUheGXkyI7Hxa2BRVorTImjVrMHz4cGzYsAHLli2D0+nErbfeiuLi4NXNn/DLXkxe87f4Xeuk+pRLrKtAGF08/MuHZ4HQCuqRboEXJ1FVC4TWCkrFq6KKidg+aeKqQd44et4qvgxMyhgIl5th06GzHEXAbBon/9hqXV+8pXIlQXtwDJYCUVTKV/D1JjPj3ce8QxKtE5NiVj+UXkbSa0JwNwoBzLxMGG4WholtCn3ovaFn5RSh85srMeyrrabM4VqVKHUVCA1pNx06i1//OoXCUicWbD8hy6QRMHONGlogdH+tbOPjhSu1QGjNKMyYYryzEkRp8Du5MBQsXrwYQ4YMwZVXXom2bdti+vTpOHr0KLZurb65GYzQjIHwaTZO7XWsBdtIByWDOhCaFgh1HEJOQal8O6YlMt9eWUhq5sajlhUKX+4j5bG3waZSrqpqvfZFL1JaID5ZeQD3TV6PJ3ycn0TLAlFh4VVIrvDJf5Nn9QRHgyjkzO0BaKc2AvyHP+9BomeB0Npf5VLpQ0563IVzIywTxhB52jVPTk2RVDLoWbReWrgbZ4vLsXxvrql7qEwjBuKCTrVaPVkHfbEJX/x+CE/P3oGvNxxR/S61TErl0yrHb1YA5Vjt6/uL9NhecKrHT/G7bFu+3Xc8yAJhQEGBJzK4du3a3N/LyspQWFgo+/M3eubxgM6F4cdKlFpzYUhRlpNed/C0qq68ZReGqYGOqWYs/NnibKC+uDDUqXb+N8f7olQqU0enrTsEAFiVdUq23LQLQ8M/7dTwxfItCjItQUYozIWhvE4F9GIgeNcyT3y9qarNBlTbNeo6CAqEcP2LgawaFgjh3Jh5Yxba6L2hl0hclGYuJ5kFQjLOuBl/3DGD4L7J42SlSI+jVD65C0Pq2jB3TpUuaV9dGFL5pIq6WzGWztzoVY6sbMtIIacYCB3cbjeefvppdOnSBVdddRW3zYQJE5CUlCT+ZWRk+F0OXoVCHj4VktL5TVXKWqe1dNNGpazl3Xq/SAeHUqcbL8z3Fs5pWDvOWGCuXGYGOrWydF7jgeBPlKKpJt70w9PQl2tC+cZY1QqnVl0YPGRWK1UQJb9ddVJUas4CIcXNmNodw2kepVMLROsQ6l1bcguE3IUh3PPS1aXtgzDXnQytQlKA+XRhJcJD+AJnfS0rmfQalI7PZjNWlG5eX+91LQVCGdP18sI9mjLqWlsNxSIXhibDhw/H7t27MWvWLM02Y8aMQUFBgfh37Ngxv8vBO0Xztx/HLe+vwd+nvLEZZm9urXkpVO1UvmZz/fOi7itkCgT/jVE6OJwqKsOh0+q4E+tzQZho42ZcP68VfNHE1SmWiiBKVP2B6JsFQr4zWn1UtZCUZhClwTJVEKXM7VP9TzjGGAov8BVOvbdirmuAG0Spd6T5+6uyQNj4FgjhXhXqoAjrSY8p7941c5SFLsyeE1NZGBV8CwTAVwAAT/XaQg0FD/DuUyknjkJmgZBlYXg/2w0yVriWJsUy6UO9rMKlW3pf3rf3s/Q+01POpb+N/3EPrn11GU4W8MuvG505PydsWSZkFYgRI0Zg0aJFWLVqFRo00K7K53A4kJiYKPvzNyoLBBgW7sjG/rzz+PUvr1nZ7MPCJRtwddoZzYVh4//Ge+tyKzRi3mcnZ717KysiCoOQ1QeqmeZupn7TqA7TnFI2GxQuKT+87iktEC43w9I9+u4ZlQJRRTG0LBBmYiC8JnP1Mu93/mcj/DX4lVW4NV0VhhYI1TJ1O90YCNMWCI0YiEr3kuBO4rkweMHNZl2DlR+M2wAwo5LKgyjlHSsrYQqs2JeHLhNWavYp7B7XAuF2o9Tpwr9n78BPu7z3jXT/pfcL7zowUogB+TG++5O1uPHtVdh+9JymzLz1pPLrjfFSBXLdgTMoLK3An9l817thDAQpEHIYYxgxYgTmz5+PlStXokmTJsEWSRXpypjXb1ckmy3NpAKh6U6Q48/pvDUtEJLLm3fzpSrS1KxKZCawSjmdN2DdMMeLRi51ujBjwxHNtwmuaDK5rO+vEqUS+Nv+03h0hn4wpMqFUUUziFaOvlaFQ6UVBtAuha7+zbyseg9mgZLyCizefZIboS+g93ZrFAOhCnbjyG/d/sAJ0JW6MCQHUHRhKCwQsrotHOXfigVCd54KyW9W60CoLBA6mRjKlGm5DJUWCI4C4XIDX284gvnbT8he1qTHV0s5E+BbILTb7MspAgAs3JGtKTNvPamiLq2sq9yYLM248nhqWW8MYyDIhSFn+PDh+Prrr/HNN98gISEBOTk5yMnJwYUL2jPsBRpV7jiA/BJPtTtp+hjPbPXxiv14ccEu47xuDsoHx9nz2iVmrWRhaG2e97amfBsOhI+bMT+4MDjLvlx3GC8t2I2b3lmluV3Zd+XvfjDH+8OKUdU+NLMwtIIoOfutpXR6fpOuax4z9S5Gf78Lj3+9Df+evUOzjZb7AjCyQJhbprdPWgqT8tBKH3LSt/iyClfl9e/phzelOleBMHEjeqc20W5rzf6gPZ03oEhjtBIoqGOBcLnd3LlNtOK4uPEE3HOqtAyq25i5PqW7KY+BkE51LkeqaAkuLC3liywQFvnss89QUFCAbt26IS0tTfybPXt28ITiBNflV6aNSYO3lINJhcuND5b/ha83HEVuoTfCWKYY6FwhyjfEfTnaGSbSlrw6EFpZGNLNcxWIyivUqj+V178W/nBh8NrvqTQLapqZDZYxVnWFyRfrgfI60urDfCGpqrgwODLpvLVb2V0zNTZ++MPzFrhUZzIrrQBKwLgOhHpiMI7ypKPAmb22ZHNbyAIR5VVYvS4MSRvZJHdWHsxCX8ZtAHM1BeRZGPKHntSFYUXn9QZR8ubw4bu6tAqbOTkb5omiLoambmVmThoXx5qkXK48/lJFQ1DCtJR8o8MY7DTOkJtMK1iz+emhNBO5GRMrxkmzBZQX5dnicnGZVoqPFRfGUZOBPdw0To3ATekWeOZeZRU76zEQ/BWU6a9Kq4xV0xyvdZO68eLnsgoXHJERchlUD0K1a6WqVghf8svNViA1e4y0LRDGTz9h/6Vv1OogSuln88fLlzkmeGilcAJAuUs5Q6WBJZBrldBTIDQsEKoAXe9nVRS+QYyD8ppULtOUTUMWzfYm2kkDrfWCKK1k+IhjJOct3OV2G077Ld0WL46Lm9qpE8cjYMYCoZWFIY+BkHcuPU7Cy56mC8PgnFAhqTBAeY4KSyvEC+58mbYL49R5r9VB+lA3O++Dsr+yCrepmAIr9f8NLRB2hQXCogZhpnlxWQX255231K8S3o2UkuAQPx89o1a+1MderjD4Q5f1rby5f5VorToQpop1VYqiNyDK3Rvm8cccE4B2ESnAuJCU2heu7sNsDIF8ufwHmQtDMeeM9K2el8Yp61f8b16D0GvJND5r4TTpwrByDeumcbr5JbaFdFuXm8letHiZRTxRVC8PnG2YmdROup/Sqr16qbfS41QmxECUm8+ICiVCzgIRiigvIyH+AZDHQCgHjTOSmAXpQCYdjPXGcOVN6KlZry2XWQpLnahwuREZYTcMolTeRFYvaDMvIgtMBCsZwR1kJJ8PnirGZfUTNE2fgIYFQkN+nkWDhy81HMzGPFTVheF0eS1pUnhb1ztuvtaB8JcCca5EZ/pppQIhc7dw3kw5e6+nNGs+6BU/ROgE+skeNpxKlLJ+Lc3kaEKDkLY3cfLKdIIo5S4M8xeCXgyEmzHuywEDw5h5O7F0T64sQNNsFoZaBnUrZUE3Hm7G/8zLnBEQ9pMxJh5DLSXfSHiKgQgDlCfpXLF30JVGgCuv3dMyCwTfvKf3Fqj8qazCbTrqW+/BtfbAGQz6YpNqeSCCKDVdGEb9+CEGQjoo/H36vOF2VfYInbbvLjE3wZsvAZBVzbowK4PL7RmElfDCHeRKgsL8K7PamJfdTBaGGfbnaluvlO48IXsK4GfZ8C0QOgqEpgtD/l36LFJWfJVaJHhpnDxZTLkwmHwds+310KpECcgDAa0VKWOq9QUqXIzrwnC6GBZsz1YFWHJT2Dk7ZqbGjn79Dw9a+yk9x9J5lACvQm9mLhGqRFkDkb7xnNdJ4zxtwoWhW3CEa4Ewd2MaVYVbd/CM54OkO3MKRNVdGEWlTp+r1mkjj3LfdbxAdmyFgl8yky33DVT6Wfv2/X7bCVNS+RREWUULxNI9Oeg/eT2On/O4bbR6c7rdWLpHOzgR8J4/vaAwrbcwI8wM0GbIqky7a3BJrLjskrgoAOo3UvG6h/k0Tl8KBarTOKUWCHmHsvFANBpoKSZMd7s82XQvQYvnTtqXUjm74PQ1iNLzXzlxn+c3xg2izMop4loslCXxAZMuDK4FwkwWhv550sLtZrLjpxlEaWSBoDTO0Ed5kqQKBC/ASeC0hgtD6gLVDGSDet74UqfLtMnUbF166Wq8uRFUQZSSz243w8g5O/C/3+Qath4l5RVoPW4p+k/ZoNvOBmB1Vh66vLkS3209btiv9GE6as4fuPOT3zFp1QFx2d+nBAuE9oPQ48KQmpMNN2uITy6MKhogHp2xFRsPncW4Hzzlc7UGuQoX4071LLMoQHjj1Ttu6vZm8IcLgzGGvZXZSW0aJInL69TyxL84K/TuL3MPG72HgWYaJ+faElCWOea6MDSuPVNKgaKNSf3BcuyNUjmTvkVbsbx5a1+of6tw869RrSJPvJcgM/Nj8LZtJgZC6/42uu8vOF0yWbWDKPW3TxaIMEB5ks4W84O2lBfN6SINF4ZBJPjKfbnoM2mtyjTrCaI0J7PZN3xZyVyO9h4pBlGqTae/7j+FedtO4LWf9prqHwAOmAyW3HuyCEOmbcaJ/AtYsN34bV96ioTUP6kCd/BUMRhjipoFalO8ynyv84ax4e8zugWOAN8UCH+5MIT4HK3eXG63OsZFsW1h//QCf91u5UEzR6QJH7MRJwtKUVRagUi7DS1SvVVo69aKBqDOwpDi2VVjU7ZuDITGT3pTnisD/aTfjc691wJhfKDFDBqTLhjLCoReFoYPMRA8XG7GfUhu01IgOOeb171RDRjAnAVC6/Y2o0BIX/K0XRj6BDuNkxQIEyhP0pnz6lnjAPXFdFrin9NyYfCihh+evgU7juXj09UHZcs9FgjvunopPL7MjMeLx9ALoixSztRp4o3O7PP0p13ewMpTnFn6rFJwwYmzxeW6WRZqWbWFzS9xYsCUDXho+mbd7QbShWGEI8oT5KklgtPFVINkhVuuRD3+9TaPTDqmbqbx2Qh/uDDOVt5jdWpFI97hjQkXLBBGpayV8ItoaW/frEVQq7IkILdICL9pBlEy+X89zLSV/mRV2RWObWKM57hfqGIWhtZvvBiI7cfydWWSwrcqKdtwXp5MXJ9mZjbmcaHcrAVCvx9K4wwHFCeJVxkN4LgwivgxEOUayoQRpU5tC4RyuWkLhOSzsrY9YD4GIq+oFNe/vhyvLfpTtlwdsW9uf6UaeV5RqWF7M70eOVuimcLK68NTaEifTYfOAgDmbTuOlxbsVtXb8MWYYPYYKV1rynPjiLRzlwtUuN2qt7sKF3+f5cXHFBYIxm9nhD+GPuF4R9rtMrnqxldaIAxn41Qu47fTROMn5TqyYkeS+4xBbvkzG0RpBu/xMLeOZQWi8oUjMdYTbyJ3YZjvR2+fXG5+DIQ0GFYuE08B5Nog5N9EZcu7XG8adP2+1e5nJSXlLplVOlwLSZECYQLlSTqrpUAoXRgSS4Vc25TkC1u4abXmNOBh3oXh/cy1QAguDE57qeRT1vyNM8Xl+N/vh+T9q25Uc/srTT08V+LUfRCY7VfZhzoSm6kUDLPj9XtL/8KMDUfwx/F8cZmv03CbLj6lqpAq/x4jWCA0Vq/gWCB4cTCAfq0HXytR+sPOIlh47Ha5kl47vtICoTsXBkcmzg7oFnszsBSI7XQyryp4WRi+Zi9J24oymmvvuwVCrUCszsoz3Y+hC8PCY5JrgTCxTeF4S4+BmRgdrcvLqMqrXgzEhXIXek38FU98vVXs59qGyfyOKAYi9FEqoqc1XBiyctFuJrNUyP1dksIxJp4W0ZXpbloWiPIKt2pCIV+CKHnKjDqIkv8mamZiJt53LZQD/ymNY26lX+UbJ/+9RHIOmbavVYmQjZOnVbLcAr4WklKuJVggtJ7UFW6mclFVuJjhQ1RtgdBWLvTwR6iH1AIhncY+Jsqz77rTeXMsTFwLhM6DVesXtQWCf68wJrdI6AUUSn83ozCbcmFIfvNZgYhVuzBGz9tluh+9zVZoxEAYySSFG9ei0UZ6bszFQPCFN7qHS8orFFkY3s/7cgqxL6cIv+zOwe8HTgPwWEPuu049K3WwLRBUSMoESg1Y6fsXkF4z+RecshtSS9vkPXgTYyJl5XnjHBEoL3FX1oFQt7/lgzU4oqi0aDblUzro8uIxIlWVKPn9aL6JqdoZisQlr7AUlybHav5u6qHLFG923DdQeZ9mZuQDvOdUquj4msVh3oUhx3O+vUuFh6i2BcKt8qFqKbRa86cov1sysfvBBiHcY3ab/L6KrlSeyivcmtYrj0KpepRw22mhWRlW1Yf3s3ISM+kxN3JhWIuB0LdmKH+zqrgKYxfPAmEFvXGqvIJfylq7PU8G43MqfLOsQJioA8HjQrkLLEr+XSBfUtztXOVLqM3Gl4diIMIArXOkPJ9ShUEZaCl1P0jznXlavxAAJhAf7dHzyjQqUSqVB0A/XkLWTjK48hQI5c0r7dLU24vOm5gV8gwCKd3MI//JAu1ZW91M/2HHmDog8PbWqYayudzeinLSgE9fLRC+uj7UFgghiFJ7kFPWcnK6+XEbevO3uJUHzSR+t0BIrmWpAqH1YHO7zSm4+i4M/nI9K420YiODYjItAwuDL3UgzF5OVtypUsQYCB9ru+iNCU6XmxsDAfBdDNxKlAYvClIZpMpcILMwyircmjEQ0tLswv7YYOMqC8G2QJACYQKtk5SWJH8jlj4wlCb3cpm5Sm6BUA4WdSoDwATioiPE9cze4mUVLlMPa5kspoIoJZ8l0mjdMHKFw7egQsA4E8PNGO75dB06TVip20bvWceL16ifGGMom3TglCkQPg7IZlfjlviVHGBjCwTPhaH1ti7ZhjKNU+bCsGCB8KMCYbfbMPSGJoi0e0y9gtuv3OXWfLBxszC4DxvrgupZaf7KLZL9Jp1BUlDUtLZozQIh/NezoHg/+3q9ihYInxUI7d/KK9RWMoHkuCjVMl4hKTPnmefCMJNmrPWSYBQDIX3pAOTHThogKrax8etSBLsOBLkwTKB1ki69JBYn8r1vvNIbVVqDAFDGQMhvNJebyVKGkuMUCkRlepreZFpKypzaZa/l7dRTy0qJEN3oanOoLADTRAyE5w0/MBYIxoBdJwr020A/4M8TNCk36fKsMkqk51OqQPiajunrekp/utcCwW/v8S8rgihdDI5I9QWvV4kyFIIoI+xAozrx2D0+E45Iu1gLpLzCza1wCAATftmriuY3U/ZYCWPq46iVfeR0ufFXjrcWCmNKF4a3Tx7W6kAY4xcFQoiB8NWFofOb08XPwgD4FgJezAtvbNI6ftIXKlPHWOM8mannUSYpcnah0uVss9lk89MIVgobNFwYVIky9NHSgKWlcwH5DXi6SNsCUaLQ1JUXuNKsHBdl3QJRatoCIQ3oVLdXuTAkTfaeLJSsay5633cLhH4qpzCI6aG2QPAGFslnZvwmAcgHHanlyWcXhs9BlEz2JiOmcWpcNRUuN6cOBN/1pVdwSK9GhK68kn6kM6dKMbIiu1yCAuHZ15ioCNhsNq8FosItcxlIWXvgjGoZb9+NdEiuy0fDhbE/97xKUXfK0jjdle352xKWW4mBMOs2DJYFQu+lqKzCza1ECfAtBLx4FzMVRwWl/eWFe/RENdW33nLp71JlhzGv8sOzQNhs/BfZYFsgSIHwkbjoCJWrYXXWKTHoRZmpIfN3cSwQUpQXd7wjorIP85Uoy0zHQEjdKbwgSqGWgCCbt9P//nZI/KyVLid/IDPfYyAK9S0Q9RKMXQ1gaouI7GfF725mPDU6IPdpn/aHBcJXBYLJLUrCZFVa3TndnEJSGgqT2+C4eT+bl112rDWOlVGQmGiBUDQTYyBcbkvzrvALSRm/Tar7Ubbx/N+TrbaSWYmB8Ic7RQtfrzteHQgrGMVAaL1l897IjdK9BVTXcOX/lfvyNNvw0LrFjeJJKhQuDMD7IpJ/QZ29Z4PNVGnt6oYUCB9Jjo1CLYfcB5dXVIZ/fLoWgHcqb8FPp1d1zOhii6sMoqxwM9mbvs3mjdJVopz6W4tSWRYGxwKhuEK0+izXmHNAHjNRhSwMQxeGccduRZSkcg2mWMoY4x4TJcoYCEEWn9M4NR+miu+cNlamM3ZxCkkt/TMXX60/rJJHui9aAWh62+Ihi6Hx1VrjFiwQ8h2RBlFqWSC4MhlYX7jrcPtRvhR4vu/JLpQvh7IOhLYcgMQCoSuRvA898c2kYhshrUTpi9Ksl63kycLg/8YLouS5HIUYMinqeCd1/2b2RKn8NK4TBwBwGbx4KCfTArzjiCyIUmKB4FliKAsjDOCdo4SYKNSKUZvND1dmRAgWiPTKQMsyPQVCcbEpL1zpDSDNmtjw9xlc8+oyrsxaKZ+y7TAmd2HoWCC0ZBPQihVQxkz4noWh78Iwd7MrBg6NQV76s9PgjaZj09qyN69yl1tMwfV3EKXyDUR5XTKmjFXQf9rwgig/WrFfVQysrMKtGyjpcwyECQuEdMxcW5kTL0VTgZC4MKyY1vkxEPrr8JUOfh9/KhQIQFEHwiiIUvjFxIE2MxeG28Q5ALzHk4dggQB8K6FvnIXBf0jyHqjc6bw5IukpwVYQjllMlB0f3X+NWE7dSBlzMbUFQhhHpC6MMqkCQVkY4QnPhBbniECCQ9vvLsyDkV5Zu0AWA2HRhSENqpQ+rHJ1zPqlThd+3pWj+TvgGTxkFgiOFcEbRMmXTVxXS+OWvfEzc096DqfPl+s+kM36hPX89UzRj5kgyqgIu8pELgRS+loHQutt3CgfnoFZChjjzYXB44LCmqX1cPT8Zv4ES1tqnVrpvTfwfxtVvwsDv5YFokwnjZMrE0cOX1wYvBgIt5uJLoxmKfHiBmVBlAYuDOGaMqswG7WVXmt6Dz2xKBkHIQYC8C0OQu/wVriZ5vVrNo3TjIvJjBWJh3D8Hu7SBHe1TRevQ3PTecuPlfAyJw+i9LoweLeqH+ajqxKkQJiAN27HRUcggWOBADwXnqAwJFVq53pzv6tvXPn3xNhIcUA0688trXBh9uajum2UgTzc6bxV0eVaDyMzMRC+a/ouN5NNo67ejnG/njoQ8jf06EhtCwuDcQyEmzHVoCkoEP52YSgHC3UhKX62hJYYHheGt5chnRsjlZO2Wqoou6suwiO1TphHei40lUMD/aZCEUQpII2BsOLC4J0zXwxJvFoZR86WoLjcBUekHc1Saom/8dI4jStRGssgttFpK4u/0FMgorQfFdGRdlHB0Mp40cPo3tWyapiNgeDX9lC+Pfh2rwqrCeOkcD8ZWiA4MRAXxBgIqQujMgtDo5CUlSJbgYAUCBPwTlFsVCTXhQEAHd5YIWYoCLn4shiIcqULQ98CUcsRKd6gZhWIMqcbB08V67YprdCXA+AEUWr0xZvExrOed3lZhRvbj+bryqSHXiClmUFeOTmWahWmtkAYZWEwprYoCZkYvrsw+OsZBVExyM+hN/WWj1NRSOpfPS7j+osvOF2yacuV/ekFWBrJK6ClbBkNj1pBlI5IH10Ykp3Zl1OIBdtPGPr1zcRNMOa1PrRISxTTthnk/nLRAqFx1rz3ofGBNuPCkE0lrmuBUF8XAtERdtEa4Ms1b7RKmdOCAsGNAeJYIDSsaJfVq6Vqq4e0DgngvQ4N60AwtWIkVA8u4LgwAH68A7kwwgDeuB3viEAtDReGNOAvNkrIoJBWolRaIPRjIGo5IsWJkcwOhkWlFZqTfgmoFRmOBUJ5hWhZIDTLBXs/T1z+Fz5ZdUBXJj304iBMvZFBqSBA5WJRzoVh5MJgTH1ORBeGz9YW/nL1vCRqeCWnNStRKiLc7RqpYqVOl7x6Isc875XJNw3CbOCoEmE9dQyE537xVKI0/1YsFePZuTvx9Owd2GowH4reA0pq0hYCKK9MT5Qdd+nbaoWBBUJUCi1YIPSaymYC1elUaamTEhXprZIodNGpaR1jASsxuk+0rIBmLRD87vlWtCvTEw3WkyPILogiyGRsgVCXWL9Q7kKp0y3bX7ESpU0jC4MsEKEPNwYiOgIJMepKaEpio70pmAJKK4I6BkL+XW6BMOdYV1a746FUZJyci16vlLVsXclF/8uuk972khWOnVWX3JZy42V1dX/Xz8Qwvts9dSCkD1h9yw9MxEC4GVOl5QoKhN5bSN1a0Zq/aQ2oejU5PN/lMRBGFmzlIGez8cvlljrd4mRh/O1KZdfYGAfpfmrts1GhnAotBcJXF4ZkB3IKPQrrgbzzWs0B8PdZuM4iJG/mWTmee7JVmvwhxQui1Dppvhi19B6E0uvb1yDKSLtdfI55rznzghpbIPjnjxsDwXVhGFsghO96FVd5CO2Fh7twjxpN5+1yq2dXvuB0yVI4pXLZwK+JQhaIMICn5MVGRWrGQEgRLAdW0jj9YYH4+7S++wJQ+yt5D0vvvuu/+UjXfWLmNvGz9CY0itC+puElur8f5cz5IW7H7BuZwkUh+x3qh6GWa0a6ji8WiAc7NtL8TfZglVwbyockT37ewKUlRoVLPtOh3cYfpDwWCG0XhjKuxCzSlsr4FAFDC4RBEKXLzWTuFyOkx/S8xqR5St78ZS+nH89/4cHiZkB2ZdXahrXjZG25k2lpPLwszYXBBGuGdmue9YOHbgxEhF18kBnNJqonpxaWYiC4QZTqdZXL/vvb36hwuQ2tIap+lC4MkxYINy8Lw+mSBVBK0U7jtCSu3yEFwkf0XBhSYqLUFgiVC8MgBsJut4mxFFaK4hihskCYSMHSi+jntpcs1vJlChj5+D9ZdUBzKnWz5kZpM+VYw5g6BkJQjJSDvncl73EU8uFPm4iBiNJ5o3NpDOpG1iDGFDEQBsekwu2WvcHYbTbuG3+p04XiMu+1onxT9bW2h142h1QmPYyCKAFoDso8pCWnzSrrX29QBysL++adzZbhZIHHopGWFCN7dZSeY6MHsDjDplmFGeYDW/WuVz0LBM+FYUWRNHpoaykQvEqUvJcgvgVCvszpYvh28zGuC1APrwtDboEwMxcGLwZCWVpdQLuUdXAhBcIEPLNubHQEN+BM1Y6jQKizMPRjIJJio8QgpoAqECaeBFpNzFRsVJrslJjIKMT6g+ryw4DJoDKOgqDXj8e87Nmvu69O5/bpZkw8J3UrZ1EVHjx6PmWe+VVA5peWWSBUwqq+y7IwwJuumr8dwDP48WQudbp1LRBmXBE8lOds5sYj+M/8XTIFxeiSEC0QykqUEb4pEMLxM2t9MJJLeGssLvO+XaYmeTNdPG4n3nTe+gq5pbkwTJ4TPQUiws5PIwTkLgwjSyUPZdvGdeIw8pbLxe9VzcLgHUueeC8t2C1LfTezDy5RgfB8t5sMJuVmYZTrKBA2/vEPdiEpmkzLBLxTFFdZc9+I2GghH9378BSCF+02zwNZKwbi1lb10b5JbbRMSxQtEL6Wi+WhDC7jBVEKb6TeNwt+X5ppnIosDD3M3AuCQqbejvG6qiBJRY45b5ZJ4SGrlQfP4D0nSXHylF09n7Je/QUhi+NA3nlMXnPQu47KAqF0YShiIJj+calwM9lBt9n4CsAFpwvFkuuONw0677MRyrbCPAS3tKyP7i3qmepDGQUvECXRKPJ9sEAUVVGBEK57QVEUppmPr4ydkkorD2QU5NCST9iAGRn0+1Kip/DabTZERdi593BUhE1883aL27RigZB/j4yw4189LsPszcdwIv+C5osH7xYyKiS18W/PC4jV+OatR86hSd141FZMX6AMlhWzMAxiIDyTaXnaxEdHoLgygLJQy4UBKiRVoxDKSxsRq4iBYMxbN0AIwtQy//e6KhWP3NgUACQWCB+rE3FQWiBMpUJquTA0lANpaz0FQu+BekebNPGzciIyATOyu93ygeO9ZX/Jjj3vgStYVrTS2KTnU6z5IVggfHRhnDh3AYwx3P3J75i79bi4XKmw8gLBVMqo5laEeQa8eFwYakqdLpRI4whU27Vm9jVqKw3YNFIqBYVJadGx2WyiG8OSC6PyEi0qM78OD2HfBMXmXIna+gB4DiU3iFJLPisxEKI1wNxJ0ds2A+Nes1ERHveFcPStuk148gmn0iEpBsaD90DltRWOWXFZBfpP2YD+UzaYsuQKUv22/xT6frYOXd9epWoj3G/CvRnhgwVCGDekQZRK97hmJUqKgQh9uIWkHMbuCwBwREknwmKVZYE9vwkXidbFJt2uaIEIoAtDTwavOZTfTsuFIX0T0YqmBjw3ntY4Vzs+Gj1b1gcAzYA4I41fkEXvtuaZ5gXLilYam5t5z0ly5UAgHAtdF4bS5i7hQqUvtFhxfnilq5Xyyy0Q+i4M5XWnNRhdKHfJFDf1cYLmb74gVSZtCguJEq00TgBwVD7wCjTMwjyEc1ZVC4RwTJSKTVplaXvpfkkDXz2zoWofRW8hKeMjbaYSpRS9wD+3m3/NCkqFd5xgsm2bQWmtEB6UQt9asVO864EfA+H5z5smWw/hGK/Y65lgq4gz9ihdaILsRnPouCSVKJMqKw0fyDuP6WsPA+DNTmuj6bzDFa4Lw0T8AyA3uTtdTKb5ClkcqhgIJmzXu+WYAMRA+OIO0botTAVR6lgg9GIC7DYbalUqbJoKhIlJrxgMBl7GVFkYwkNW64EvdWEIJceFAU84rbyBTm9/AeBEZcS+QjzFttWuL+nDiDH9h0eFi8mmpNeS6VxJuW7siK8xEJoFszTk4C3VSuMEvEqfJReG208uDPDlqq+o9MmYshKl/sN37pbjpmecFE6aaReGngLB+BYI4ZqpWhCl/LtwzKSpuDx4LmR+W88GfC3sZnRcAK+lyZvGaUaBECwQnufA8r25yK4MtE2pJVcgbBoZUsH2YVAMhAm4QZRR1lwYgEfrFd5WoyJsonVCbXYWzGLeZY4AZGEUlVo301qdTti0C8Nm033YxVVaa6TZAFKM6jUAwhu5QRtFEKUymp7TqcqFobRApCXGiAODAC+CXEo2V4FQKgzq9VQpwTr763S7xf16vlcLzZgeYWZZrT5lm/SDC0Ma62EUWOt9A9RWIKw8OISmvtwbvH6UCkRakrpUuHIuDL17bP72E2haN9504TTA/D1rqEDoKGnKNE5rriylQur5L8SxaI15UmlioyJwQVFyXdmfdIwwM/Oo0ELPkijo63aFC8NoPHIxJo4TSbHqekJKC4QNGmmculsJPGSBMIHWXBhmiJEoEJsOnRWjbGOjIsTBW/n2zrteBR+8P10YH600rgrp9W1aHxiU7fXMhjHREZp922yeQCNAu9a+mYeEUVAhz4WhTNNS4maSIEpFDITwNlunlgN9r20gW0/LoiEEafEUCGEXcwtL8fGK/WKwpVR+eRaGfrS+y+2dXEyv0uAZRUVTteVAuk3zF4hWW20XhvqYaQVRAvr7pIXwsBDiMPQKfukh3C8qC0SlAiHeV1Bn3Rhdyquy8kzWgTApbCV61iMGT3CjEpULgxn3pepb0VSY/yHaIAZCejkIBft4CkThBSfmbDkmu45NKZXCvui0dSnGB7NZGG43Ey2VphSIEI2BIAuECXjnKN5kDESE3YboCDvKXW4M/XIL6lVeGLHREZoBN1wFIgAWCF+wbgT0rqHnF9Sb2dQGmzhN7nkNF4bZFFS9B5xSwZBaLLTM6gxet1RynMICIXm4dWpWB99v8wZEalkgLk2OxdnicrFmgFx+T3+Pf72VO6cIY+o6ELoWCJf3aOiNQ2eK1YpKWYXLU0DIZpNFufsjiFKmQBj0IbqYeAqETqCqFkoXxjUNL8GyP3Mt96NluUrjTFYmc2Ew7dknef3rthHjETz/e7dJw/GzJfjjeAG3/cId2Zp9uZk8s0VAUITVk+4Zy+dtK28sKAxRkinZeUi3GVOpbPAsC1uOnMOWI+fE+xMw5/IU0J8FWFAUPd+FQ2TowpBYIKSzLQuoLRD8UtYUAxEOcOtAeB5ob/dto7uq3SZ/ExLKMcdGRYg3pLoOhODCUMdAXDCRhTGie3MkxERiwPUZhm2NUAZRWq3UZrZ5vCNSc+D0WCA8x1sr8JOXgqrEbeDCUE22xbQrHUrblJTLFQghYFY60ZNy7XqJyiApD4JiylMUBdm0JiRjYJw6ENymADzHjIlvUPJtSFHOqVJcVoEub67CkGmbAShjILS3p5aXT6TMAuFdzjsDehYIvUwXLYR9Kax0YTSsHYc68eoBvlEdjcJiin5iFdlaQhaGdL9cijoQxm42c5YebzyC578j0s49TgJ6QdVut1YWhpYLw7rbSEBQGLxZGBouDFmQufELnbTGgpmZcs0EhCqzMITjazwXhtdSybVAcGIgeNYGS3PPBABSIEzADaKsvGDvuz4Dnz/YTnNdu83GrSEQGx0pVs/TskBItyvcIMp5F3g8k3kFdo3LRHPOzHK8twgr+NMsKiU2KkLzRrUBhhYIsz5NvVYM8oHPzaRvksZZGMJAwJjn7V6aISC9+evWisY1Gcnc/oQBmWdRMRyUmTILw9iFIV5rOrbQs4oYiGNnL+D0+TL8tv8UCkud2Hz4rEQE303XAvKHnP71KgZR6sRAWEFZSKqWIxIt0hJU7ZY8fZNuP8Ku1VJYKlVpnIzJLHNmFAjPevzlT3RrhsaVyo3QxGtl8v191c0Y1+0WLbowlEGU1vqWIryZG2ZhSIPMTSgQUsy6PHnySRFjXZSVKA0VCLe1GAgb/yXG6njsb0iBMIFRGqdeDQO7JB9dSmyUdwpcwZxWXFaB6WsPIbuy6IwsiFKYTMtE+pGeXMlx0TJTXqAxYRgA4ImB0MJut4lv5loxEGYsEEZpjYBcwZDGQGhbIKRBlN431XKXW+YflZpb46IjNWMqhAGZtz/CmKRX1EqWhQHjIEplwC5PKmX6mrANNwPu+vh3HJbMUWJtQOM3jtCyQHCEE84P14XhgwKhLCSVEBOJl+5oJWvz879uNHxgCf1I68VER9hRu9JcLZVWORdGVd4qYyIj0KMy5Vn5APTMtuqbCuFxYXCyMBRKhVdp8V2RFBQGKzEQMTpzdfAw88IhYC4Lw/NdUCTMTaZlJYiSn8YZZP2BYiDMwNPbpf5VvbQqm40/4MdGe4MohYv5zV/2YcaGI9ztipNpWUi95A2qNhjPL6Bew7f0LMCKBcKu+fSxwevC0MrCMDMgeCpPaqPcPIOxC0NaEEaqmJU5XeLAo7RAaE2bDXgHaZ6PVjj28Y5IlFWop2pnTDHYMf39rXAxSapp5XnWaS+uJ9nGYcUEZ5Ym09JoqhUDwbsP9VwYWoqWHsosjMSYKLRI9c6eeXVGMlpJpnzWpLIfaWxPSoKDK6f0eLqYcRCl3iG2S9xlyqm/bRxXmlkYY4jiWOGE61X4yTuBl7W+pThdyhgILReGd28i7XZERdgM6y8IuMy8cAhtFfJVuNwYPW8X2jZIUs+FoXgp1KKswiWeFzMKBLSUP7JAhD78adi9C8td2g91u13LAhEpau+Ctrrmr1Oa2/XNAqHermY+sUmsXq9mU+hioyK0+7Z5LT5adSDMDBxGb+TK36UxEVrudKnfOD46UnQRlUtm9lMqH/bKabN511VUpLYLQ1iklQGkLGUt7IN3u/L2FZK3XSuXhG5QmYV+NOtAaBWP4gjp7yBKoT+pBcIX3BJlT0CqYErHjwpFJUozSphWG5mvXNHEBv41ZwbGDApJVZ4cf5SyFq5hQwuE9LPN2vk2o2eIFhyFgNuP5eO7rcfx0coDXgVWTOP0tDF6oZHGOPEUiFjFPW4D301HMRBhgNE9p2eBiLDZuGWQPVkYlQ8LF/9Bw4uBUPqj9eBbIGwqTbbXlamafXjTs+RvM2YxrUDouDBssPkpiBLQe8R59lEymMuyMPi3inQgcETaxUGszOkWHwxKF4aeu0AMrNWZVTBeo4y60gIh3xt1HIcniBKijGbRU9YsZWFoLJdbIPjWCAF/p3EK17kQa1MrRl1S2Aw8ZY/3oGBMHkTtMrAaCWg9oGw27/0tzuwpHiMzkvNxMe1S1p7tCkuY7J8ZtJQN8V7SzMLwfo6w20Tl2wxGLgZ5W7l8h04VA/BUtnQr7h+zMRDScYx3XSiVIa3JtCgGQoNJkyahcePGiImJQYcOHbBp06Zgi6RJuc6AqszCEJDGQHi1WHkbno9PWd5YD27pU44F4oEODU33aTULw6wCERNlUAfCH0GUBlkYnjbez9JnuNYkXsLgFmG3wW73FgeTWiAiFS4Mm2KwkSIMHNyHtPBQ0kghrnAx/Pe3Q7J9ke6P8gFS4fKay628meoNvv5wYUgxkkt0ExkEUfIGab3+vBYI3+KFRMVMcrMlavQlnwsDYAbPtsILTszcqJ5CXEA0QKie5b6HUXoqUWpbIIRr2Uzgobpv/nIjBVD6IiSky5vFXNC14I6Rtz18xqNAlFe4RZeysGnvdN76J1GIm4q022RWKgGVAgG+khxk/SE0FYjZs2dj5MiRGDt2LLZt24a2bdsiMzMTeXl5QZHHaBDrdnmKzroaWRhR6hgI9QPf+10ogRsVYRNrSRjR7Qq1XMq3YcBaBLNlF4aFLAzNNE7AOIjSxBuFJyvB/O/SB6W6Nr0HIZJaOJdi6pnTLSogdrv8mItT/3IuLDEGgrM/wkCm5dv/Yu0hnJYUl2KQ75DyIeu57hQuDBOnS3fOBJ31lcqfGWWDZ7mRIqbKGrgw6mukzar78/wX0jir6sKQipUY6+1LGqegnM7byAbx9+lizd9sNm/nQi/SGAhfgyAY42ciqdM45f/N9a1vgdBCqZRbsTi5zLg8K5soX4IEBQKAOPmV1cm0BMXDEWlXWV+fvLk57HabyoLMGy+sxqT5m5BUIN5//30MGzYMDz30EFq1aoXPP/8ccXFx+OKLL4Iij5HWnlE7DhvG9OD+pmmBiFbHQCgvEOnXK9MT8fXQDlg+siueufUKU3LXqeXA7vGZeLCj3MKg3I5esJnSnWo5jdPPFggt87mZwjBGdSCUSF8ianOKvXjkkQd8RUty1711IORXkHj8eTEQOkGU3hK//PO16dBZ2XdlGqfyIeuZuEkhkwl8jYFoPW4JFu44YaqtVVn05sIA1LMbasGYJ/1WUHZ8VSAEpGdeywoin87bOIhSD8+MqnxrgDTA0ipuxrguAvEhr3R1WqpIyseojof02EZojLNaWMnCUDY9fNobOJxf7KzcvlyBMOvCiI60i0WwBB7s2AiAfP89aZzqfsgCoaC8vBxbt25Fz549xWV2ux09e/bE+vXrVe3LyspQWFgo+/M7Ju661KQYrsasWQciyluJ0inxlWtt1maz4YbL6qJRnXhLo0AtR6QsBsNmk5uyI+w23Zkh1Zi7ZIWBxOyNqhdEaYNNrLuhhfk6EHoPP+VkWt4vWv5jsU5EhNwCUV7h1qwD4XVhqPuLihSuCbUFQpBda2BVrsMgV/iUpYidLm+aqvKtVQ/9GAj9+Igth8/JBbSAXhaGURpnnEbcCK+/Eqc3Qj7BIX/oa90p6vkcvA/tKyuzNv5xjaScuaQjWRaGySBKLWyQFn8THubCb1ULouTNhSFc98oXDQshBtoxEIYuDO9nqy4MMy4WcV8kbRljMguEkOIsjA9mJ9MqES0QEYiMsMsUYOH+lrqMPOeVZ4Ew3I2AEnIKxOnTp+FyuVC/fn3Z8vr16yMnJ0fVfsKECUhKShL/MjKqXn1RifK03d+eHzPAewvyZGHwgijtoklQ6y1KK2ebt7Rp3XhcdWkibrtKHRApHVyz8y/IFJWoCH5+sUoGwZ9q8oK16gvVqwMBeB5+epYSM5YO40qU8geB1LRs9IYunEtp5Lg0wE96LoXDrZcezBuAvBYIvixKBeKrdYdlWTu8fRAeXlaeK1YC0JRIZ8Y0c2UYBf6ZDaKMiYow9fB0MSamcEbababrCyhPl9dtYMN3j3fGb89156Z/MsarA+E70lRN8VKu/OD5zTcNwuXmF5JSxkBYjZECtO9do6J30lOuVW9HC1MvNpX7Ir0XT50v4wZy2xUvBUaTaQnB14LM0jFaOM7Rshc/voIUbAtE2NeBGDNmDEaOHCl+Lyws9LsSIR383723Lfpeeym3He8tyG7TqgMRqYqBUA6CVm716Eg7fhxxA1fpkCoIbiZ/kETZ7dwANC3MDhBCK7NBlDe3qIcZ64/otqmlUf/ArFyMGU8YJI+B8H6222x44x+t8Z/5u7jrRokWCM9NX1YhSeNUmI6Vg428H+NKlFrKjDIbqLjchQ+X7xe/865PYaCz4sLQG3w3/H0G50rKsSe7kBurk1/iPX9m3rRlWRi8GAidIEpHhNQC4Yk5Mkr3ZUyewmm28JJnX6T3mfehHRsdgYza8tLX0v1yKtI4pesKMpnFBnWMkzJTwBcYY6Ym0xJuoKpMpiVgVMdDegztGg9YLazM0CptK3VfSFHOxmlsgfBcY8I+RkXYvXOACC8jCgsEL9gy2CaIkFMg6tati4iICOTmyiewyc3NRWqq+u3a4XDA4TAXIOUr0tuudnyU5qASwdGYdV0YihgIs54E3vaF2gL89trfIyNsuvXxlZi3QHgGVKMbqUndeMx5rBNSEhyGftM4RwTOaMSPmStNa8YCIe3T+0C22TzZKloKhPjWIKaeuWRvx7xgQG4WhjApEM+FYTA487KB1h48LX7mWZqEh5e1LAztg7h8bx6W79UOdj4nUSDMjOEy1w9PFr0gyki5AmEGl5vpZmBo3WPKfTH70GaQH0+XJJLXXpm6Z7Y4kkc+Tuq1NA/DRx3CzfhBjWIapxB3Iba3EmOgZYGw6MKwYoEwWTcGUCgQGgOQcP2ZTeMUfhZklrswbLLfAAA2fhxPsC0QIefCiI6ORrt27bBixQpxmdvtxooVK9CpU6egyCQfxLTvQL4FQquQlHo6b7ULQ0MezjI9U6/yYpa7MPQtEErfptkL1qwFwhFp18xwUKJV/wAw9zCyGpwmPW5GA6/w1iDMmlpW4ZYHUUqvIZ+DKAWTKl8GbtyEpBveNSKso7V/0ktSeAhbCUBTIp3QSFlGm4f0J73pvI0UCL0gXWV/ggvDbOAlwIutMecaevOXfWIkv7B9MbUW2nOwaGGTBOwqPBiaEzKZwc0Yd3zTms7byouxsL+dmtYBAHEuDytpnFoF+7S3ac5iqWx7WCMDRhDFrAVCwGuB8O6L0IcsiBI2vgIRZA0i5CwQADBy5EgMHjwY1113Hdq3b4+JEyeiuLgYDz30UFDkkQ/+2u206i7wCkl5TKpyf7deFobRcj0lQOljjFAqEAGxQHj+m03j9Kyk/zPXhFeJ2clxdC0QkKfQSY+b0ZukcLObCaLUS+MUzJZObhqnsB8aFghOwR2ZAsHZnrCO1v7FREWIPt94RyRKyl2idaSWI1KzLocWMgWiUrYImw0VGvskfUhwLRB6CkSEdQsEY8ynKpRK8YXTp2Xdkx7uUslkUdLpvG02/kuJHjZJ594HYKUstqrUgVAH4QJSBcJW2Y7JtmkGYX/H3XUlNh46g1tbpcr61kJmgbBZm33VTNq3gFSXFywQyrLZqiwMk5MAeWMgpBkXaguEJxNNfQ0HuxJlSCoQ/fv3x6lTp/Dyyy8jJycHV199NRYvXqwKrKwu5H5Y7RuQ9xDXMq3FcOpAqLMwzLkkjORSXsvSplEGLgw9c6geYhEWCyOJUUs9BcLMG4V0cOZuXxEE4TJ4+EpRBj55gigr11UFUQrmTnU/ehYIj4xMUynTC9yy2eT7EGm3ocLNsC+nSHMdwGMpExSIBEckThWViQ/tpinx2Hm8QHd9JefLKlBe4UZ0pF081Ha7TfbEke6e0eNO3wLhHXDjonVKpUv7Y/ouDC2Up8TXgV05Q6q1DCl5EKVbcc/a4LsFAmAyn7xAlEYWhi/TedeKicSgTo3F5YYWCMlnq0GUZl2egHwME2IgrkhNwO4T3ow/YQwVjq9ZK53wcsm7fuUWCL4FNtgWiJBzYQiMGDECR44cQVlZGTZu3IgOHToETRbe2yMPSzEQ0d4YCEFbVVkrLdzselaE6xtfopJJINLAhaHCogXCirnbaNCJ13mL1FMghGNjqhKl5LNLloWhv14kxwIhrB9pV9aB8PznKX1eBYKvDHgCQfky8I61cExtin3QyvhRngNpkTFBgRO2ExsVgSGdG/OF0UE02UssEJrYND5XYrYSpdliaS43cL5MmEjL/PuV8vozioHQ2mM38+6TDfy3fj2kSoIokcSF4SueLAxtC4RyMi1h258OvNbQiuINDpYvNwqKlBVns9tkQbNGVDUGok2DZK4swnVodtwTLRC8qdIVFghuJUpSIMIL/RgIXh0IjSwMiQXi0OlifLPxKE7mlyq2ZV4GvXu0lyK1UypmVIRdN35CuS1/l7KWPkSNutbL5dfbjHBTKwwMKpQKhlR2o2h8IUeeW0hKYYEQjim/DoR2FgZQOcW4D1YdTy19uetKitb145CkMV4S7ymmdb7yDd1us4n+aisUVLoxhDdjvQeMoQXCdBBlpOlJqgQLhHIeDN31FN95lSjNInVpWnZhSNwU6kJSvmsQHheGtsKrtc3L69fC2DtbqdZT9s2Tz1IdCKsWCB+DPEvKXbDbgFZp8pRc4TT5GgPBG9uiFTEQPIKsP4SmCyOUsR4DwbdASGMgthw5hy1HznHXNSuD3uCg7EdZB8JMoBZT/DdszwlAMrsNLWppzAEB6N+wdjsAl1AHQn8r0lLZWnOU8OAVkhIGqUi7fCIc7+HX9ttrWSDczLcUORvk16fyYaCMohcQ5gCx2YCGtWMBALlFHkU3wq5fQ0SLc4ICITw4zNQhAV+ZcOu5MCT7aNqF4daPgdC2HKh8GJ72GivojSNibQ4b/6Gth90m7bvSGiBqkcaKsBZuxvSzMESrB8OczccksS7a2WECLsn+yvu24MKw2wzrRvC2qYcYx6Voe+klsahbS16ZVpmFYVaBEJQe3rWmtECEImSBMIFsENM5kVpvC5oxEAYXvJVrxsrbhZUsDCWmgygrBy+jG8mKrzTOQlS8FGH/3ExfSSkqrcCjM7aK370Dm/HxUZeydov19u12efCaXgxEdKTg1tKIgQCzlMMuDchTxkBI0dpFwfQfHx2Jegme+VhyCzzzbdjtNs1ZSvUQUjnFGAjFtmVphwZ9adVQAXx0YfgpBqIqb/0yF4bk+NZyRKJFaoL+ytIsjEqZvPqDryGUqJwLQ88C4eFCuRvPfb9T/N1uM38MlOfIUhaGzWoQpQUXhuLk1kuIUU2MpqwDYRYhBiKBM7bJKlFqdEtzYYQBcjes9gWidfFoT+dtoEBovr2of7By4UqbRtpt+i4M4c1CMRgZoaW9K5G+uaUnx+q2tZJWJ0V4uBhlYWw6LJ9LQniGm7JA2AULhBBE6RIHKaUFQln2Vt6P4MLgWyC+3XgUGxVzXujhtUDIz7PS6qS1j0IlxrjoCDHdVphALMLGr89vhFBMSiy0ZfLa1Uvj5JayjlAEUZq4eD2FpKyncSpvjKoM697UWrkLY+GILtxqllJsUN+zsrkwfNQgjAtJeTredlRuSbUrUpj1UM54a2UyLat1IMxMpiWgdBnGRNmRGMtXIKwa5ASLpbI/QKkQhaYJghQIE5gNorRigYiNihBrB2hu18JFY2VgUFkgLKVxmrvxtLR3JdI3gX9WTiKjhdlUPCXSIEq9oV15PoQgSDMWCF4QpfcBaZfd/8Lx53VrlIUx7sc/DWWRIh5epQVCZf3i76MwqMc7IpFSS16vI8KurnpohnylC0Mvg0hy/fCamQ2i9K2QlFqBaKcISBawHkSpvc+FF7wKjPTetEG/Fgog1HoQ3FFyF4Y0Q8MqLpMujPUHz6jkMTO8REXYVBYEwRqnhXyGW4uVKE3Gw/DaOiIjVNeG+FJg2QJRNRcGBVGGAabTODUtEMobw27Kf6xpgbCwbR7KGAgrDwHzFghPS6N0a6l2Hx1pxz0aZcIB48FTC1kQpV6wpeIYWomBiFLEQJRVuGUWCN41xK9EWenCqEKxJjlec7j02jXrwnBUKhBx0RGop5gS25c0Q8AbAyGgd+0apXSaDaK05MIQszC8b4UrRnXFmNta4Okel/PlVMnN9+ub4VTllOzJcVGyh6rdZkPc/7d35mFWFOfCf/uss+/DwMAMs7DJIvsOKoIBNe5xCyZiFJdglOg1gXiN2QjmJjf3Rr88JrnfDckXNSZ+UZOY8KkXjYn3uhIxwQXlc+OCQJTAjKwzc/r+cU73qaquqq7q5Zw+w/t7Hh7mnFNdXV1dXfX2u5XEDwgAqA2z2LEerhNlllff76F+l2XIJeHdH1KDxIOsVVsDofF8sXNYOhFzmDDsPBAu1/qDy6ZRn602XzQju/3C5La6/G9MGCcPzANRApBjQu4DwR/A7MC23uq8+kDoOlHKjk/EY47F5KQxzXDgcB/cfFp+srQHqrIPRBa3hC2sdC97M9OV7i2sxSXjslERez9ECb64x8acGoiyRH5xo5worf851bpNmrqQb59x5r6TCBNJJfI+EGzG0LjhVQNxjNJkyerIUBoIZzm5E6W3RFIfcTQQ3c1V0H1ylfC4bz+6Db64bCzU5bZ9d9OusF8bRv6YDz/KmnjqKpL2/ghWGSUNhCCtNCtE6rD29HHyME5JvSpjhHd/ki4aCFYzrKOB0NvOm9VAxBwaA/ulwGWOqi2nnS+t+WLi8Fr4zzWnUs6ZqIEYJJD3TpoHQtEHwnpYXLUGQg2E8wevTpQpjgnjhGHV8OtV8+GkMc2OY5UTSZkAK//Pi7Cn56i0HGuLlF2GR/khL0CY8sRWf3rzA+oz6cxmccYk534sAGQiKcKJkng7JieWfB4IZz1ukybJP184Ge6/eo7DI5yEdKCTOlEKji9PZa+nMh2HxkqnCUM31TJA1oRBTnzs+CN/c4s4UXWi1NnO20sY58+ffw8eemmn/Vk3jDMZyz+H+w7mBIjyFHWfDDBcBSHSz8H2W7IHgTcnyj99YRFcMb9TnkhKUHGMEZ5FsP4PAPp5IPQ0EO6ZIkV+XNYW3GReGjYKQwQbKUK2eXhdObVWsKmsuW2Uni18UIBQgB4T+iYMoQbCow8Eb4zqOLOx4XyyME+HQ5ZqBlgT4PFX97gWc2ogxHhVwVrHyfNQOhngvEXOH9XELWvvhZGbAI4R23knGHNV3uFKrBJWoak6DXO6GkHWa6QqnTyfaN8Vdr22NBAV6QSkEjGor8irbrNRGMrNtfn7oWPUfZAJ0uR4MyB7PXdtehP+sG1v7nexEyVpOuQtUDwGMgAHc6G8uiYzUlgmF20e7NeJuGGrvz88mK2nvjLJpDh2F4RIY5ntAwH+nCitnUR58xWbB4JF9Zw8E0bSTSAg6tXNA6GWSCpnhmU1EEnLb4F4FnJtcXse2LHOc7C3UNNAYBRGCeBcUHmo+kCkbQHCu02SxasJg7dgyZqlugSrJmqRpdlmz+nVhJt3otRT+eWdKIn2CSZKNg8EGYXBJgSyBDZeTTpqWGvBkfUL4UNJRWGw40W2FwZAPguoFcqZPT94CuPMaiBIE4a4LDtBvrKrB/758Tfg1oe2AoA4DTwAPQGXa5gwLKFFNyTP3mxsIAM79x/OtUvtWDIayjJh1JanHGY13n4IFBwNRH5zLgP8ePPLfCBk2lKvJgzXKAwmNFrn2dHJpcJqIKxnoqY8L8zJXgpIWCFHJvSwqax5oAaiBKBtbeIBIhIIHE6UuQfRqxMl7zAd/wDWiVL2O6udUH3uZPsykLCqRKkPhEcJIm/C0NNB9BN5HOz2CZrAzQNhpbKOsxoI639O38fUw95sU4ikDGmL52mWyPPymDeqEeoqkrY5i/SDiBneNRDknCzVQBDlTMir93fuPwxH+gbshYC3uDVWpmDUkCqYMbLe9e3Uyqg5YJra4aUW/QMZME0TvvLbV+CvOw9AWTIG87r5Giv2OSLzsXyYu8b6iiQjeKpoIAgfCCutNOEH4we+BsJaOPnHZDUQ7ifmCXjuJgyibCKmp4HwkUjKDr0so7VxAO5jhtU4yK4xTWkgRCoI6elCB50oFTAEf7OomjCscof7BpTP6/aLng9E/m9VlbnJ/O+GugDh7gORT/3s1YSR/d8litMBLxmQqAVsHgjShBGPGQ5vegCRM2zWHHJMof/YDXx42G/wBj0+2QlJJPzO626Cl247zS5PCRAeE0ntP9xHCXLsfSVvEZlKO2OacJDY/XPHvkPSMM5EPAaPrj5J6Z17+sgGeOfDQ9m9KBQjKJ6/dTHMWrfJ/tyXMeGeZ9+Fe559DwwD4F8vngqdTZXcY9kFjPSTsX0gKpJU/8YMQ7ofTLbNeQH0+bf35frHnwkjX7fzu4SLCSMbPuxeN8/EFIsZ9qZvbu2pTCe0zH9KmShz/7NFebkbVKMw2OeMHN8sKpk1Pzx4DI72D0hNIWGCGghNpBoIwQ1nb64lyR843McrbiOSOnkPpE40He1M5xwCMvWeugZCrSD7cIbrRKnnA9HPcaIUtS/B1UDwfSCsOnj31wD10EirTpnWhkxlzWbuIyH9RBxtIo4jBYi4YehtxJbjWH8GDh/LC8+yt7YxLfnMi339GWr78Hc/PCR1orTqVtHOWRN7htgN0+3aSHMOQHZRemDzfwMAwM2njXHsQUPCOvOS+Vg+sMM4U45shG7ZWMlL3XXgCGz4z7dtk4zh0YmSPD+L9QYt1paqmTDKBZoVmVBAXk1lKs7dMkCEig+EBc+JEoCO0olJnmkSdqzLNBCUCYNT7dCaMjhwuA8efcXd1ywsUIBQQDWVtehtTKSBYBPzOM/L/573QOosjOSEyvP6J58Xr2dS1UA4wzwlehePs19+cdTzgbAmed5GWCxsHohjRB4I1gdClrXOMNR9Y6xSsn6xQ/gMgxIy2UNU1fVDHBoIbzfFest2Ozd5vw71DVACxHv7DkmdKHWwngly3OpqvPoGMnAkp1Wc1l4vLcv6CCXi+TFiRYHUlSep0EkDVDQQ9Hi97/n37Kymybi6eYxbN2fsW/OHqK+yAoR73eWCN3GZWYKstzLn5KuKig+EbcIQOFF6MWGwwmxa4txLOVFy+n7RuCEAAPD/934kPWeYoAChgOozp+oDYb1hnjpuCFw0Y4T2eXnP6q+37FJqIwBjwlDVQLAhYS6oChCs/CCb4LzGsFPbeWuIWgO2CYNsBL+spcmhduO0Fre4QS8E1tsKpzID1M1KMmdMCzIYgM3cR2Kd0u3+UhqImL6fgIW1HwavLSJMM+9gCJAVIGROlF4g1eW6eUcGMiYcy+VtcFvMWLV8Nt05fb76ypTTB8JFA8Hud5GMxew2JeMxH7th8AXepIsGwn2EZhFFyUifBeKklel48D4QVhSGyAeC40TpGoXBdJQfDYQ1NorpBoEChAKqTpSi39hBYscMxwy47pRRSudVOY8q8pTGQI1I9lTqPhDeojV4V5YPQ/MnQGQymlEYA87zilrgiMLoI30gYlwNBPdyDPUdGHmmkJPHNNMJwIQmDPocqn1Las2yTpReNRB58x1bxy9eeA9+95f3c5/oG/a33nyo5LsfHpQ6UepgObzSGgi9OvoGTHvcuwmB7KKUiDnzsdSVM06UAIoaiPzneMywr0lngeXXzdFAuPjAKGsgBCYMmVmCrLYyFbwPhF2Wk8oagNZAWIKB23PEdpfMB4JyouT8bp+qiKGcKEAoQI4J+Rsy/3v2TYZeTKRn1jqPKuxeGCw8DYS1iB9xcfy0CNKJ0sKrltp7HginBkKkBbFNGLk3qaMDeRNG3OD7QPAmGwPUkzPxji9Pxqk+zIfAGrSnO3Oo7U/h0sfOKIwgNBD0bz9/fgesuu/P3Bj3vbmtxAEA3nVxotTBGq7kuNW9toFMxnZ+dVvMHBoITkr52ookLRgZagmx2CgrSgPho5u4GggXE4YRUxNORRoIKxyW3x5SA5HQzESpnkjKkRuFkwfCyJ3abcywz7YfDYQtP0jPGC4oQCjA24qZX04NXlIhbn0aGgidiYEsy/P0lQnn7+07pHSOvn7VjFM0flPi8o/L/p8x9RKv8LbzFmogLBNGnPCBsHatjBnUm5odl8+pLGaoeV+L2lKRilPtJTUQn5zVDgAAn5zdLnai1DJhGNq+B2y2RfI7lr4BE/7jtb3Ud3sJDcSOfYfym1b59IGwxgXpXKc73voHTOJtX36sQwPBZIStSMWzGQ+ZKIxUIiYdH+zul1kNRPZc6YSaADEyF9LKwleju5swVORhK+OpDmwUhpYPhGR6WjiaH3prYWsgCBOGqgaCFXTZ9PAk1F4YAZnoggYFCBU8aiA+dyrfPKH6hik6Fa8Ndy+fplQngLsGgt7AiFyMTPjvv4vfCEj6PG4Gxdfqq70di7DeCnWjMPICBNEWQRvyGoh8fx7py+WBiBkQJyb9jMRun92gStUHwtmmMkbFTfpAzO1uhBf/cQmsO3eiwxau+rZdW560JzYvu3FamSz/riBAPLB5h+M70oRBmsn8OlFa2iZy3Go7UWZMW3B229OE1UAkGYfUulyIILsbJ4BcC2EY9DOU8OADMbuzQVC381jrbZz8hcpWaqhtpiUyYcigTRiaURgCCWJYbRm01pZLj+XmgVAUINLJvBC3esloaKkpE5ZNupow1IT+MME8EAoYgr+d5fK/3r18Gpw+aRi3HKWBkEx84jBO5/cySVZ2ft5ixTVhmPTk7YZXDYRssvGqgbAWctDMA8HbTEvUBjuMk+jPQ7mUyPE4/aYu2+XTAJ0oDKdgVcGYMPLXm/2yKefDwMqwqgKEYRjQXJ2GnfsPZ6NLNH0PGipT8MFHxygNhKhPt+484PjOCnFk8auBsBbsfh8+EJQJw0UDwZruasuTlFmnNrcpFx3Gmf27MhUXhoAbZCpKyJpGDuZCZlPxmJKaVJhCn/Od9TZOPrf1lSl7x1VWoBGhmmqchNxorDKdgIzZLylNI/KBIDU4pmlyNZZ2FEY5GYWR/V/2HBlGNovlXZdOhYGMCedMEe88DMCYNySdWMwdOVEDoQAdxqlmcpCVU/WB0NFAqBtQ6ON5G+SIyr7zoZr5AkDdB0IFv06Ult9GhsgyqAJvoRc1IZ9IihQgsueNG7QAYbdBYIrSdqIk7j2bEliUBlzFiVJ0rU05YTVm6N+Thsrswkhu6S2adA8dc/rb2CYL5hCvPhDfvWgy3HDqKJjekX3rJk0Yuj4Qx/rVnSjZBaypKk1pJmtyOQboMM4sskgMdsGOxwxbmE8mYkqzhI7plJeJsrEyv7mbah4I1d1SScjxUZ7Ui8LoOeIubJgmX9Dg54HIXqPstltzwMdPbHUVHgBoMxg3YssWdFyrCg0UIBSgNBCKcwo79/zymrn238X2gfCSSMoEgHc+PKh8Ds8mDMl1eH3JtN5UdPNA6IQI5sPZ8jn5rWRJbCIp0SIIoOdEaR/DmDBos5NVrxw7JEyhf6xIDC8+ENabJplpU7T4HzwqdthlMzx6deY8f9oIuOljY+029BGqbV2789H+fHt1BYjGqhSlRbEc9NgwTgB5JIYB9HhIxvNZTVWdDIX+DLzxan+Z/3F4XTnxuz8nShlkJl/d3TiPCTSksRh9nbxwT+sloT6nJUrEDCUfCN0x6uZEGQVQgFCAvHmqC74zzp5WK9rHeGmP4nciqDwQnIdO9Jb+ro4A4dWEIbsSrwKErYHQ81jOJ2EimiA0YeS/tyYYa1fHeIy2A9s7ZHLqMcFUdqK0w0GJ71gTBu8ayGMtdCa3EfXZBaIqndCeFPPOmu75Fg73id8Sx7fWUp+9ChD547P/9xGOr7qQb8RuizUbGthYlabeXm0NBLkbZ+5OyzYGy2abJF8QDMqxU00oypf5zzWn2n/L5j5yzrB277TarNKVvN04AQD+6YITJUfRfagiIKlt0GXYtfMFiGxbGypT8KUzxsHtZ0+wx7B8ryS95Tbl5gNBtLNYoA+EAuQD6aLwt/9ixwrrGZ3/XiKQCM7Gm3B13pZItWiSU5coE+W7BTBhyDUQ3haJruYq2LJjf9aJUkMF0e8hCgMgayPtPZrvR3YxyoeHCoQRVSdby4RB1FOeilO2YduJkmm5IwpDY8G87pRuGFZbBhfPbLPt66rw03fzkWkgThhWDb99Of/ZbxindS/sDdQ8VEcJEJqJpJqqUtQ1VNkmDGLs5aqUbTPOdkMynneiTMXj2iYMUpsgg8wQagmYAJDbf8O7CaOzmb+XCA8VwXt8aw1s3/sR1V4S8r6bZj4XDEkZ4Sh99Und1G8ywTMIDYRhODdHQyfKEkI1CoN9aERmA6kPhNCEwSkrrsZxLCk08KMweE6Upp4AEUIUhq4A8ctr5sJfdx6AVCIGW3bsh2P9GS1pPb+TJdEWQRPIyUuUOMyC55xJnlPfByJPeSpBOdiJ3C3Y8clbgEWtaKkpg2tOzk6cR/uPCErxYd/0AcRvUOR+GRWpOLVAjx9Wk2+n4d+J0roXlg+Dl7C5wzmNU8xwXyzYMM6mqjRjwrA0EE7hVeoDAXSEAZlIKplQS2UtKiJ7/sj8ME1EsjHDUE0kxRcgdPLkWOZD2UZ0ZckYTGmrg6e3f8Cv0aD7qI8TrSHbuErWR6qaRQsqjJPI6Zl/KciCTpQRR9mEAeJy5CdVDYSkRc5vFKthwwR5ixUlP1hSLuj5QBziSPgqNsogfSAmt9XClQs6bQ/73iN9nqR1OhOlQGsQJzUQ/M3TLEx2BiB/A41U1pwKKpJx7vfsN+xnD5tqZo/THL/W2LciY8okmfgOESaMSmbRJNXkfrUPZLusxddLnZaAo3L/2Em/kdVApDlhnEQUhgjDMCj7fiJm2BqpVFzNiVKYFEpyMPlc1ZLRCYoaCJEPhKrTuoXbHJNKxGFae520TnJh5jpRSsZskBoIyoTB0TZGARQgFFC9ZbSgQf9GayAMYTlRfXRdnLIaDgKk0KCcidLMb/KjwvqNrzu+SytMrLwHhMymqIPVJ9bbXM/hfvBiMVTZTI3U6qhqIHhNKU/G1cM4rWJE8fJUnNtGtu9UNBAq6DpRWs+B9cZanowLhbpDhAmDXTSrCYHCr/aBrMOPCcNy6lOxxbPXXJlKcDUQSU4UxhmC8HCrDGnCMgxCA+HTiVImLJICUXdzFXV+PxoI0bgUNcVVgIjHYOrIeuHvZLW/3rKLm3lXdn/lURiaPhCce0+2z+qDHz71lpZpNkjQhKECKRjI4nxJHwjHhJ3/27cPhI8ojKwJg/CB4AoQ4uPrKpKw/5B8G3IRqUQMwCWVhOwydCd1q7wVr9171KsGIv+3qAm0BkIuQFgCGulI97MrZ0FFKgHlqbjyRG9VS5sw1LzZ2fHi1QlRd/G2xu5hSoDg3xTSZEFqIMqTcaq/A5Af7DqsxdaLUGIJ2CqaNvaSY0xESzXPByL350ljmoX1xmK0AJEx8xt8ZTNRul+XqISqBqK+MgWPff4k25nYjwZCtssnD3LRbahMUflGALJ9MK0tL0CQPgVWvVbVb39wEDZu3e08t4fcPQAefCDIcWQ5TJNzEfHhhXf+DrMECcDCBDUQCqg6UdI3V1zOvwbC+4xpAJ38h2fCIDUQrBAzslHdqYlFKcxKasLwtlhZGeN6Dvd7shaq3FcdHwire8mkRbM6G2B67s1IfaKxJpV8ebEqmP7sx4mSxGsqa8u/oSwZF/tA9PEFiMp0gnoz9bPDpN0u2wciJ0D4eMZ0NnWyMIC+73wfCJXF36D8S44NZOwXgqSiCUOcwE58DCsQjWmptucKP1EYopd2A/hTxe6evE/O9y6Z4vg9lYhBbUUSugXOmVkTRr7mt/+mbrYFkGvydJOu8TUQ/LWITEJWSFCAUEBVMKCOcXi987UOujY+0ffK851BT3DWIP3Vdfk8FbK39DrCvqlLKhGDdedNdGme+EJ0J3WruJWzvvdIn1YiKf55+W1ISjQQ7CJrp00mPLxJrZB6Vkhni9i9MNiyFo4wTo8LptcwTks4KEvGle4JacKoTMftiAQAjbEva1fuOkSRMzq4ZaHknt+gEy7l80Bo5gIw6BwHR9jIEF99JT5Ydgdd0zsnnDuRuh2r5gzqLGTNd9Pa67n1s5/JzdtUkPpAaA7UJEf7REF8J8prETYoQChA3jt1J0rmN4P8TVEDIXhguQKE4sxgAL2gWVLx9JF59RepUnZ7c9UhFY/B8tkj5e2TXJvuQmEtpJYGImMCHBSEb6nUk4U/VdJ5IOi3KZEPBOktT2oA1IVUZ/nyFD9Ujx0fIhOGri1VX4DI/n/EFiBiUpOZBRl5UJliNRD+cS4k3uvypIEw6L6sSotNGNJ6gF5MSC1OKhFTcgESmw3cj+W2yeU4mdlN1JZPzmp3rZfXXksLunzOSDhxRC0smzjUUYasd08PbXN1G+8yTZ7us2IYhi1E2M+voIqjKEBEF5U8AGw5diCJhIZCZ6I0GA0Eb7KrlWgZ4l7d9UExCoPznSgdsyrpRMx+8xDtISCD1DyK1lfyTdHdhJETIDgx5tnzqWogsuVITQa7nXe+LP9YURt5ZXjovlVZ57HaXJ4S+0CQkBqIqnRCqNHzCvsY+DFhKDlRMp8NAxgNhDiRlAw2CoP0I0nEDCUTnljzSf/QTkTCyO6hW1/KslDyunL8sBr40pkncMvf9vHx9t+8cWHNQVPa6uA31y9wbBwWM+he/huz94rbvZXJCF4ES+t8PG0j2dLGqhQUAxQgFFDQYDuQaSBUTSKin3wJEGBQqjFSdX7npVNh4egmWL1kjPB42TPwlbPGi3+EIMI4vU3qhmHYZgwvAgS5+6FomiT7VNWEwctyB8BfzHm+BtY3h5k9AVy0nQAgNmHoGnh0fSfYSb1MElNPQvtAxCnbeBAaCFneFl10UiqT5xsgNFK2CUNTAxEz6DThlqYnlXOgVDEXieed/N/nTxsO/5cwe0rr86GB4AkBp44bIszFMLwuv7ulTAMhPh/9+UNGgHDL5SATqL2YxSxHSp62kfx7bEu1dt1BgAKEAuRtV9UYOMLmgP+b7K1CywdC1YRhMJkoCXvt2ZNb4WdXzoZ6YjMctlbZQ7Bifqf03Crb7er6QKjO89aE7CWChHTwEmogOH4lFo4ojNz83i9IeMNblHkTH+tPYLWD7wMhHo+ic4YBK4CWpdR8IKoYJ0pqLATQdHbi9+UDoRTGSV+zAbS2wDZhcBJJyTDA4JowrDGpYqESzjtECy6e0QZDqvOLtaxeXxoIzWeezt7IeY5cNQh0um9Wxk+5CLyy58jLlvNODQR/PAThB+QFFCA0kZowpGGc/DcJnUxrsmN07OaUD4SmScLfm5n726ZcA8Epr3hua28BlV34WFRCI8k8EKphnKJsnbxJkydAWMUOM7HqfAGTpqqsOBHc7LWVJcR5IEhILVCVIwojgHbF2OfVe126GQcBss8VeR+t9iRcFkQWwwCY091of7a0U9b4UdEwqTgusgulLBuiHwGC+9JgnZNzSrfwXlkSKAD3++62e3HgGghik77s//xyQUQieQEFCAXIcaq+mZbsN7GgITqG/t77YDEYCdttslOxlZ8wrMZl05ssKrZhqYCm8GYtwsoF4cmEQWoghE6UpAZC7kRp54HQMGHI+k7JA5up8qbTxGYqwSFC/uOmk+Dsya1KZdn7VZ6KKQkQVWkyCiMR+BuXyKnUC15s3THD4G5frhrybWFA1sHwY+NbACAvXFrPuZKTrILmk22LXAMhP53chCE/loUU5FU0EGy73eYTNsssi0wD4WVcsFpbkdYBNRARhhxkqj4LsvAgv6on3iGq1bDldAc1T8LeeONCuGhmm+uxKiYMXofk98JwFled561IDJ4AoTPBiU0YYg0Eq+WxBAfRhmM8oZL35qQzdtiiTVVpmNBa4yjnxX4/aki1ND0wuQiyC3NDZVrJhEFqILIChHyhcGNsSzUlRLHtCt2J0uFFyU//nnBZEB0Y2Ws5b+pwAKB9IADol6HfXD9fUIVI8yl+8ZHdQVcfCKkTpVjdyjsnmXyJq4Fw84EAeXv9OFH600BA7n9yHeFrtQsJChCa6Pga0MflUfUgF/3ixxcADPotRDe5iR9bud8oDL46U609lld7b06AoLZXdzHjUD4QgjKkMyo7SbHV5xNJiTQQzu94E5do7PDWY15ZngLkf106jVunG7JxQS4Q7CTaXJVS0kCkEjH7LZrURgB4mzwf/fxJcMPi0fl2OQR+/TotvGkgAA5x0iZrP5+562CjXZK2D0S+s4fVqu20aSHSorof52LC0AzjlNXmJnCpzUESLYKHHB8WnnwgbCfKnAmD+I12A0ITRmSh8iJIekz2gNFCg9p5xRnhvOsgHBoIl8WTLe9n4yIlE4akeu51q2ogGBMGFUrr0izZG5IFOdGzk5RDA2FHYahrIHj+I6L5iKem5gpmnHKTRtTyK3VBdhvKiAWCvbSmqrTSboKGkddCsBtrBTF1BhmFoaJRYa/ZMAxq7w8La+zo+DgBOAU1nhNltcAPRnQusk9Eoclux/EQbeWdPVZ6qAPe3iEk7LPJttvtzrnNYTLh0ZsGIic4WIcKqkANBAC88847cOWVV0JnZyeUl5dDd3c33H777XDsWHHSdPKQ3ScqD4TDiZJfzsu5uOuoR6FEV6MQ9+AgZuFX+vchP9hOlB/mcuOT4YNuQlF5Kt9u0URJvl2wIWZsF2dsAYJfl2oUhqiveLXyLtFtndOalCSFyR032b5urEorJZICyOeCqEwxAkQAsyc7ufvxgRDdVxls+KVFwk4kxKemLAHXnNxlf7b6wiFAJJwCRDoRg2fXLua2hYfo7ZetV7U+C1EaawD+syC73aQgzxXE4y4vA6apHOXBoywZhx9cNp34nC+vq00iz8eTH0R/F5JIbab1+uuvQyaTgR/+8IcwatQo2Lp1K6xcuRIOHjwI3/nOd4rdPADwnolSVE6G2Iky+DpVy6toEUT4zQPhR3CqYZJjpZMx6M2FeLsJUeUp98eEXMTI60zEDMcCZykeRJMuT6Dh7WQqujdcEwZnhHhJ6y1Cdb1ln5+mqpSSY58Bhq3qDkMDwfa5H6FE5NtC4thMS3A+SwMh+n3JCS3QUEGEXeeKjainzRP22CDES8MwYGhtGbCIM+DKF2cRwUdhZL/jDRs6jNP5OzsHtTdWOMrIWqsyh5HZLUlHad2IN/J8PB8IVef+MImUALFs2TJYtmyZ/bmrqwu2bdsGd999d1EFCPJGSe8TpWWgfyIXKdXJVm83Tn9aDVW82PEsvPpAWAThA2FBagncrqlcIQ8EXTfhyMWp223hVtdA0FhvO6oaiCAFCNl96OvPn4e9ttrypHJ2xPmjmmBv71GYOLzG8ZtfnFEY3usS+bZ4IcGqsBlM5jfrz+7mKmipSdupmJMS4ZJF5cXFYcJwr1aIzIQh1w46z0qaZPmmQPrGLho7BG46bQx89/E3BDUyx2sODFqA0B+odh4IcI4DVef+MImUCYPHgQMHoKGhQfj70aNHoaenh/oXJlIVO+UVy7zRkH8zv4keIBVbJK9+GYahls5WhB/VrvUwVKfFcquuD4RuFIYFpVZ3eTMoVwjjJEkzGggWUfim3R5V5y+mWGtd9q1TdT+LAOUH6X0jVfPstSUTanthAAB87ZyJ8OfbTuM4/wVvwvDzRifybSFhLzlmGLDhipnQWlsG91412/7eGj9CcxWbkMow7P8XjMpv+62TB0J05aJIMl47qONcHlKZCYPnc6ZswlBwRjYMA64+KW8CMk15/apRShfPyEalfe5UwlHXjxNl7lBhjg50onSyfft2uOuuu+Caa64Rllm/fj3U1tba/9ra3MMJdVEO4yR+kzlRsuNo8z+eptUeL6r8xlx2ybndjfbfSudiBqYf1a71MGxcvVB8Pkn9/IgutfawJowyKjJAfizpA6ECKUCQk8YnZ7cDAMA/fGys9HjVKAx2jI3OpbNVjcIIUH6QTl/D6/ILPnsPk7GYogkjV96PakBCkAJEnwcNRMzIvg3/19rFMH9Uk/29Za4RRSrEYoYwnG/B6HxCKV4UhhDhIiUuIt+NU346r1EYriYMzqiUJWQDyL4gyOYUVQFi3XkT4ZHPLYAbiUgfLz4QluNw2o7GEDCYNRBr1qyxExiJ/r3++uvUMTt37oRly5bBhRdeCCtXrhTWvXbtWjhw4ID9b8eOHaFei1SAIP6WJZJiKxE9QHo+EPIR9PCq+XDL0rHwzXMnwaKxQ+CK+R3w3YsmS4/h4Wf+th6CEfUVDjW0CtxEUorHsiYMSoBwtdESe2EomTD4YYvrzp0I/7XmVNecGbomjLWnj4Oupkq4TbDBEAB/EnczYei81fAm+s/M74Rblo61kxoBOK8tHlPbn0FGGCYMP2GcKhoI5/n5J2yqSsM3z5sE//QJOlHbmtPHwfC6cocwStYyvzsviBw6ls0xodLVomuXbWvvJ5W1rglDdglkYjyepo+XB0InskzVhJGIx2Di8FpqvHvxgVh5UiesmNcBp08aBgBOYceiQJnoHRTEB+Lmm2+GFStWSMt0deXVSLt27YJFixbBvHnz4Ec/+pH0uHQ6Del0OohmKqH6ZiKLwvDqcCb73q1ZbQ0VsGrRKPvz7WdNUGsEaxv2MVtTb+aitxypCYN3gNq5pSYMlzcDlURSJKwTpYVhGLaZQYZ6KutsuWtO7oZrTu7Ot5EzxepEYYwaUgXb934Ep44b4tpWGfO6G2HJ+Ba4c9Ob9nfsm34iZviyywOE40TpKwpDQQPBagJkp7M0VyTXntwN1+buOeUDQXwYUpN3kHzx3b9nz+vaMongKGmjzLTnNmXInCh1fXfIjLA8AUK2pwxAzoQhrJ1OVKWLlzE1bmgNfOVscq4WzZvFkSAKIkA0NzdDc3Oze0HIah4WLVoE06dPhw0bNkDMx/bRQUF5LkvK0Q8y8xupZlSc8nQEiEIRVCIp0cOkvZmW4rkdURg6YZwKiaRISDWql0mDm4lSwYnSwm8Uxn1XzYZH/vI+XDB9hLSdbliXQZ6Gyr9hZMeTmgJC3I9BPA6s2SHsKAyWoBYAtpaupkp464ODMDK39baSuUiogRCX8aOBkPlA8J6fjMRphtRA8DRB/IRs9GdZD/mJQvPjgG4hdKL0XbM3ir86E+zcuRNOOeUUaG9vh+985zvwt7/9DXbv3g27d+8ubsMoHwjJREbcRnkeCLXTCk0YGmWDxunboX6sigAhTyrjPEZVoKlMxam20k6U8jrIdutOwF7UlsomDEGz/eaBGFJTBp9Z0Am1jNAlQ6b1oFWthEo3Nxn7NmEEMH02V9NaTFVN2//+9Aw4Z0or1Ffk+2r6yHrX44L0PyFhn5F7rpoN508bDv96yRTl84qunE6jrI5b2QpJmDTvmbfkB64PBPG88fJxyDR5Vp0yDZKXVO8WfnLoWIjvje+qPRGpMM7HH38ctm/fDtu3b4cRI+i3H1XP8jCg423F5VSdKFXvtVYURkgjiK3WT8KdjsZK+2/RW0kLoXY9Y9JQ+OMbH8DlczuybXEPRBBiGAZUlyXtTJRlAj8F7rGK5+CV96JA480zWomkmGyDvUf6lTNRhgF5GjqFuHrPqjoveyUeM+CBa+fChT94BgDU79uS8S2wZHwLHOvPwKvv98BT2/4GVyzo8N8gj7B90VpXDt+9aIr9WSXiRU0DQRcKaztv3hAZkJyMFL55goCKACATaIutgSDBPBAMK1ascPWVKDaqe1c4nCg9nEsYsuNjIfWLbJtyGX/4h1OgoykvQIg8kocRiW2+ed4kqEon7DdVvuZF/cpryhO2AEEe11ZfAW/s+Uh4HO245A5ZtxcNBE+g4SaSElRNvvHXlieh90g/d1UIW3zIJ/zJn4knQPjXQAQDuTjoTsipRAymtNXBlLY6tQMCDaFVb6uaBs3dvKjTO64ChCTKibuHS04Kcgup5powXAQIkzkulYhRu9360kAEYI4vovWaS6RMGCUPqWWQaCB4fO2cCdBSk4Yr5nfk6xCU9bWZliZstew6pqrqJYUHAHF/DKvLCxADGZNyivLjAwFAO1J2NefbM7NTnGckew49CYJspky78dlTsk5wXz+Hdmj1mkjKbiLRRssMwSsbZCIpGXQiNp4Jw70O2X0OSvtG3quw3+jCCqH12+z1509Siv5y+kBInChdw6T13mNVx21TldO53k2DYJom5XzZzNQhM7G6EYgPBHG3o5BIKlIaiKiiOs/KNBBuK92n53bAp+d2wD3PvkvUIXoT4H1XmBEUMwy458rZcNm/P5f77K0e0cLaXJWGxeOGwOG+AWhg8lX4FZxIAaKtoQJ+e/0CaKxKwcNbdkqPE4VOCcsTf8scvm5ZOhYun9dBmW0A+EIZb+dE8W6ctAYiW9ZZzsOWDUJkz4jIidKaUP28FQeJl0ipKECZFlzmAbeujhlqcwlbRp4HwrsJg4fbuN1wxUzYtf8wnDCMs129ggmCFCCaqtOwc/9hAAAYN7QaLpnljIhRxUseCBaxXxyaMEoemU1KdUIiJ9Mgwzi9wr75xgwDFozOx5d7jcoQSeOGYcC/r5gp+I37rfI5yVwQqXjM3nkyyNTDAPRi99YHB6XlWOEBwNmnHY0VlMbEPl6hLTINROg+ELmTCp0oLQHC72kCGvs8B8+wCLLvqTT5Ls12E4ANMJQ0EOy21rLLcbs9ugKEWybXRWP54cepeExpviKDaKyke6sWdcMtS8epN5JDMBqIPFHIA4EmDAVU3joBmG2/WR8IxVmOUvcK7g6vqrDGD7uzZFAZ+xbksu3pDHw/qawB6FDONBGF4bZ7oih0SoTfh5nVQPzuhoXcDIzCKAyeCYPnA1EY+YF6YyRbkdDZnyGwVokhx3YygLfFQkGZMHxqIMAQ93U6EYdrTu6CT88d6dCIyeZI9zBOvWXImmd1x6+qA+UA4QNhaUG9+DKxBOMDIX7xKgaogQgQ2iZF31Avt1d0DHesFGj8sBK818XysjkjobYiCTM75P4HbufyasIgnRL7NeL2fSXiUYSdZyrTCYHwJDBhEH9HwQciIxCsdUwYMoLTQOT/DmLBkBGoDwTleyUv6559VF7H2tPF2U6FdUrqm9xWp734eTW9saG6PPoHTCAVkudMaYVXdvXA4hP8JVUDCD4KAwSCeSFBASJAZCol1Td12l6sYcIomA8E/dlrxr5EPAbnTdVLVOT3uikTRkJHA8F3XBKXV24S/3jONen0M9nGGokPRMECowUnsmzCKuYCaRhnQGOfZ14pBXRCxN3Gr2EYnvpTasKQVOclnb4VxqkzfuePaoSvnu2efbcvk6E0EAtHN8PGG9WSILoRqg8EmjCii+oLkmzxl3kwk2Qi5gMBkM1mZ8Gq1wupOuOabjROTwoQZPIat8yBIrujsLzPLuG9JfLWWKEJgwnjBOAvsjIHzyCwxoboLJZK918vnkKF7/Lr8vabDrQPRLjjOkjljyxDpJfztjW4p1tn8TqUvPSyF43VvO4mGDWk2rVcX78ZuE+URRBCaScxF4uimwoJChABInOi9HJ/VRK62GX1q1dmBRFaGpQJwwt+NtMCADh0bMD+m3RKdHPK0r13ft+Iee3hX7vAhEEc3lSVteGmOXbmQmkgSEGFbLPlZzBxeC08s3ax5/qDGoKFdKIMEtkmVyxuAnDMAFg6YSjcdNoY+OlnZgXQOjleFj6355WHqgavP5MJzbQ3psVdgHHj25+YDGdNboVfXTc3gBb5p3SekiLixfzgcKJUnOZUYnu5i0mIEihZM6uBKOa+HAB6103GhZP59922Xybv3RkTh0E8ZsBCIhLF2SblJnHhaiA0tE7kG9opY4fA1Sd1wQ3EtsL5ct7b6KiL850h+Q1Azywje36CGvuk20OyhEwYZNe4NdvdhJHtzxsWj4aTx+io7iV5IAJ+vfGi7VB9+z/Wn/EkoMj43Q0L4O7l02Bqe73vuobWlsFdl06F6SPVfcfCBH0gFJjT1QBT2upgTEuVtJwoXC37We1cXtOThjrdEe1gfctkbfzY+BZ47NU9sHqJc/EKCp2144Lpw+G9fYdg0Vh6Yhxw2X6ZPEd9ZQpe/dpSSMVj0Ln299zyZJ987tRR3DIyeBYV3mIrunRyDJUl4/ClM/iOb8VIZU02OhmQo2IYGoggPOZlqEZ2qUD5QLg8EGG9XXutlszyqIqdiVLjnKrCat+A6eoTpcuE1lqY0FobaJ0Axd3ewQIFCAUS8Rg8vGq+e0FBwhwAjTBOBR8IHmEqAugdFBnBSDLP3vbx8fDF08dRPhRBo3Pd6UQc1pzujOV2s3myp2BDW2VtOmWsvgMWL9c/L37dr/YnZBcI7mZaJFoOuLKiAY39Ug3jpBNJyXGP4vR23V6H0tH+AfdCDBnbiVL9rFEwYQRNFJqJAkSAyBwgvTyWOseEGYVB1u3YTEuyiCXjMWjLbSMcFkFcd5/bSqrtA+HjYOA7N+qZMBTPUwwNBIGOo2IhlnMqxLSknChJDYS/83qVSaWprCV1etJAeOg7VWG7byA8J8rBCPpABIjUB0LVhKHgA8ElxPmObIeOD0TIWmAAcPaRylbKLDp5IFTQicvnwbPB8vpSmMpaOfGZVrO067KEOyoPBPE7LzmWCNk4K0UTRpBQEV6uTpTh4LVeLy8YtoYuBBMGgDcnzeMV1EAEiCysRj0TJZl0JxomDMrH22HCkAgQBXCwJBenjsYKT3kpXPNAaC5R5GV76QNVJ0oRUdFA2CYMwWm0nCglRYNyoqQEmpD7JsjqtQRWt0RSHvtS93o2rJgJjVUpbip393OFF4UBIN8uPEpEoZUoQARIEOPOax1hLtWUBkIjjDMMAWLS8Fr4684D4gIe+s9NA6Efxsn/WxWuBkJHgAi4nFfKczsXioRiHT8DaRSGh7bxSBPJxaIwOauikwfC7eU6jOATtsp/vnAyLBrnPbOji88zFx0BPOz8KEERBTmndPR0xwlex0S4YZykapf+TbawpRVyz/sliIeI1UCcykxuuj3r14QxdqgzXlxn0zJVzUJYGogvnTEOPjm7Haa21eXOwy+nIxTJLj+ooU9qIKIwOauiE4XhZt6q1Nxa22Li8OzOlyrhkhdM18tA+8dbFsFPrshvsOdFQ6CjgQg6CmMwgxqIAAkiNCuKGgg6zjz7YfywGnj1/R44f9pw4WGV6fCHVyACBOE09Yur50BncyXMWrfJ/k5XONPZXpkHL/5eyzTjIXNqkFx9UrfSecQ7snKOkQkQAY1+so+DDLPkEehunDpRGC6nrSrz9sz+y0VT4PtPboflc0Z6Ol5Ge2MFtDfmfSXszbQ06kAfiHBAASJAAjFhRFB5Sj56lgBx/zVz4OUd+2FeNz+hUqtLamLPbXHzMvfQf/2ETnR2VyP0HOmjz6lZH7mgectA6jxIywdCsdyI+nL4778fVq7XO3wnSpEGIm4Y0M88TPJEUr4ax6WUNBCgMd7crmuoB58EAIAhNWXw1XMmcn8LWjsaZiZKgMJFJ/klCmsFmjCKgGwy9KyBCDUKwxnGWVOWhIWjmz1vphUUQbzJnTI2a7Kor+DvXKntAxFCl+gEBaj2yYYVM2HR2Ga1HCdu55RMZqLmiMwy/JwXznKVOR+L+aPEWUG9EvZLaJDVy/K06J437LDrILAUhjrP/tT2OuWypWLCiIKcgxqIACmm5BpmHgidCcqiUD0RxHlWLuyC4XXlMLe7EQD8O39SYXUB3RYdDYQqo1uqYcMV4e93INrOWyR78nNeOL/7f6tPgk2v7YGLZ7b7biNLFN7uVNHbjbN0rkuEzjW8/OWPwYHDfTCsVr5BWH1FEv5+KKt5RBOGOqiBCJBgojC8VRKuBiL/d7E1DixBRHqkEjE4d+pwO6TMmQRMN4yTnNCDt8+7EbU1QjQfi66Jm7abU7StoQJWzO+0oz0CJWwNRJCJpMhZ3KcJIwyCnjEyGj4QtRVJyn9CxM+unA0njqiF+66ajQKEBqiBCJBCOlEm44brJlBBQS6CxZYf2NOTi01QttYUGz3iJ4wzoP7SubaoTX+iMS26Jp4AUehxVyp2cAA9gbWUNCsiwljgJw6vhd9cvyC0+gcrqIEIkGCcKNX4822nwfcumWJ/LpQGQiecsBCw1x3EPWAXMN2+JTUYxdittBiLn+yUdB6I/PciswxfM1HYfgzfByK4E+gIrCUkFwkJ+96UigARBXMUChABEsTtVB0T1WVJGFGfV82F6QNBomqLL9R0H9YCTYYY6ueB4P9dKCIwr1AInSgFfWPljyApdD+G3YdhZaL060QZBkHfu7AXzlIRIKIAChABUswwzkJHYUSFsK6bNGPomkYMwd9B8rMrxc6PViKsijB8AzwgmvDZfn3i5pPhzkunwtKJQx1lC6XJqcrlLlk0Tn8X1ShwPDhRDnjYzluHBaOzUT3N1elwTjCIQB+IQBmciaTCsOkHRVgLSyoRg0PHslsN62sg/OWBcGN4XTksHC1e4E4a0wy/um4udDZVBX9yD5BDWraza1dzFXQ1V8EDL+5w1FGoYffULafAOx8ehOkjGwp0Rv+gCSNYvn7uRJjQWgNnntga7okGAShABEghfSBYQk1lHaEojOqyJPWZbU5Qc0sqTmog9I6lywffXypvkYVeAGUtEu0wKxpKvLFcKMG1sSoNjVWl9eYZhd04ZQQ9N3U2VQJAeNdSU5Z0ZFONIlGQBdGEESCBOK95DeP0f2Yh1DbHRVZBrDtvIowfVmN/ZjUQl83J5gSY1elvAXVEYmgQdg+Vym6BFuRzUUHstaDjkFsMZ9QwCfIOUkKDqwaitMYOya+umweXzGyDL398fLGbEgmicCtRAxEg/QUKq+RRqO28ix2FMbKxEn5/40LoWPM7AHC+3Zw7ZTicMKzGfkvxCqWB8JEHIoz3hFLz8SInumpirwWRUHC4byDsJg1a3B7PEhs6FNNH1sP0kfX259VLRsOjW3fDivkdxWvUcQ4KEAHSF8DMrlMD7e1fGBNG1N4E2QnTMAwYN7SGX1gD2onSX5uCplS2G7YgHYOriA3WRNqsoxwBImrjzi9hvT267sZZWkNHSndzFbz6taWQYLcIRgoG9nyA9PUrblQvecaj+YBHx4TBEqYTpVfCDqmNogljBvFmyCLWQPDLnz9tBJw1uRXuOH+S/V3Ehl2kMNQtGIMikRTJ8Sw8ROFeogYiQMhdHWVUSba5jmIGPDqRVPHawSOst/2kDyfKsJ0goqiBOGFYDfz2+gXQUuN0QJzT1Qgbt+4GAKcTLI+GyhTcdelU2Nt7xP5u8AkQxUkkpThFISVAFJYKFCACxC219FfOGg8v7dgPSyc449yjDDknFTsKgyUs040fH4iwuygKEwePSSNqud8vn90O1WUJmNnRAJXpfG4KK0xWBJ1CPVrjzi+B3kONKAwECRIUIALETQOxYn4nrHCpI4prg06mu0IT01Df6uDHByJMfxSAaGqpZCTiMTh/2gjH9wddBAiSaI266OKeB6K0xg4SbSKmkC5t+voLm0iqUJMqFYURse28Y4YBtywdCwAA3zh3YmD1UgKE5rEYxqnGoWP90t+LnRI8TMK6g64CREjnRQpPFO4lChABcsuy7EJ2hY+wIp03hLqKlOfz6EC2KGomjJhhwKpFo2DbN5bBvFFNgdXrx4mSFLK8rvX/eOYJAADwrQsmOX6LoAuEJw4edTNhiD8heUizhWsiqUEydpBo3Es0YQTIzI4GeOWrS6FS4iQZJJ1NlXDbx8dDY2W4ggQp1EQtCsNqTjoR7L4P6biPvTAC6KKrFnbBxTPbuE6HUXSi9IK7BiLclODFJEhTgo6m5o4LJsGKDS/YWjsE8QMKEAHjV3jQnVauXNDp63wqUHsZRERnteb0cfD9J7fDV8+ZEEr9VBRGKGdwRxSxUGo+ECLQByIYdDZvO2XsEHj968ugLBmNjdaQ0gYFiIgRRSenKGogrj25G1Yu7ArNpOIvkRRhwgiqQQSDRAHhqoONxkgLh7BuoYqPEgoPg4XiTwQReZ90cvToUZgyZQoYhgFbtmwpdnMKRgTlB2rBipIPRJhtSRbZhDGY+c6Fk6G1tgzWn3+itFyhMq0OJrCbjh+isFZEVoD4whe+AK2tx992qhEYEw5IlXnUwjjDQsWJ8kefmg5NVSm496rZ1PfHRw955xPTR8B/rV0M41v9pxwvVf7l4imB1UX7iuDoQwpHJE0YGzduhMceewx+9atfwcaNG4vdnIIiy1JZLAYIFUSUNBBhoiJAfGzCUDhtfItj0sZJPBjo6ILBxaKxQ2Baex38+b39vuvC4YYUi8itVnv27IGVK1fCww8/DBUVFcVuTsG5amEnvPDOPjjzxGHFbooNqYE4XgSItGIYJ09YILsodRzn6vfNIM4DAQBQH1AY9iDsGkSBKJgwIiVAmKYJK1asgGuvvRZmzJgB77zzjusxR48ehaNHj9qfe3p6Qmxh+FSXJeG+lXOK3QyKYm5TXiz8LPyGYcDqJaPhwOE+6PC5rTgyePnK2RNgT+8RuGpBV7GbUjBWLgw/aux44bjZTGvNmjXwrW99S1rmtddeg8ceewx6e3th7dq1ynWvX78evvrVr/ptIiJhYNC4/avjJ5EUAMDqJWMCasnxC73L5OB7z25rqIBHPrew2M0oKJPb6ordBCRACiJA3HzzzbBixQppma6uLnjiiSfgmWeegXSa3tFvxowZsHz5cvjpT3/qOG7t2rVw00032Z97enqgra0tkHYjWfo9CBClPt0n0fRQdHR2mdSqt9QHJ0MpXc9gFASPZwoiQDQ3N0Nzc7NruTvvvBO+8Y1v2J937doFS5cuhV/84hcwe/Zs7jHpdNohcCDBMlgSF+ngVwOB+AedUdUYWlte7CYgxymR8oFob2+nPldVVQEAQHd3N4wY4dzNDykMx6UPBAoQg5Yh1YPrhWN4XTn8eMUMqBFkLo0SKBMGRxTe63CWRFw5Ln0g4jjTFZugTRi/uHoOzOpogJ9cMct/ZRHj1HEtMKOjodjNcAWfquCIwqwcKQ0ES0dHRyRTOx9vjGqp0j6m1O8aaiAGH7O7GuGX184tdjMQZNAQaQECiQanjGmGf/rEiTB+2PGTOTAVx/0Cig21iRvqvgcFeBsHFyhAIK4YhgEXzTi+IltObKsFAIC6iujblY8HcN1BEJooKOdRgEAQDjVlSfjLVz6GmSSLCJovC0tTVTCZMZHCEIVEUjg7IoiAmrIkbn1cRGgTRtGaMeh58LPzYF53I/zsSn6ofLDgjQyKuV2NxW4CaiAQBIkmpPanIoVTVVhMa68vWPp8FASD44JpI6A8FYfJI+qK1gZ8KhEEiSRlyTj84LJp0J8xobYcfVEQhCQWM+DjJ7YWtQ0oQCAIElmWTYzOrrSIf1ABMbhAHwgEKRFQ/YuUOhiOO7hAAQJBSoQ4Tr4IgkQIFCAQpESIoQCBlDg4ggcXKEAgoTC8DncIDJoYPq0IgkQInJKQQPnlNXPhtPEt8C8XTyl2UwYdqIFAECRKYBQGEiizOhtgVmf0dwUsRVCAQEodHMKDC9RAIEiJEMPJFylxUIAYXKAAgSAlQgwlCARBIgQKEAhSImAYJ1LqGBiHMahAAQJBSgRMwoOUPDiEBxUoQCBIiYAWDARBogQKEAhSImAUBlLq4AgeXKAAgSAlQhxVEEiJg2a4wQUKEAhSImAmSqRUmdJWBxWpOMzsqC92U5AAwURSCFIioAkDKVUevG4e9GdMSCVQCh5M4N1EkBJhVkc2w2cqjo8tUlrEYgYKD4MQ1EAgSInw5bPGQ3tDBXx8cmuxm4IgCIICBIKUCtVlSfjc4tHFbgaCIAgAoAkDQRAEQRAPoACBIAiCIIg2KEAgCIIgCKINChAIgiAIgmiDAgSCIAiCINqgAIEgCIIgiDYoQCAIgiAIog0KEAiCIAiCaIMCBIIgCIIg2gy6TJSmaQIAQE9PT5FbgiAIgiClhbV2WmupjEEnQPT29gIAQFtbW5FbgiAIgiClSW9vL9TW1krLGKaKmFFCZDIZ2LVrF1RXV4MR0PbHPT090NbWBjt27ICamppA6jzewT4NFuzP4ME+DR7s02AJoz9N04Te3l5obW2FWEzu5TDoNBCxWAxGjBgRSt01NTU46AMG+zRYsD+DB/s0eLBPgyXo/nTTPFigEyWCIAiCINqgAIEgCIIgiDYoQCiQTqfh9ttvh3Q6XeymDBqwT4MF+zN4sE+DB/s0WIrdn4POiRJBEARBkPBBDQSCIAiCINqgAIEgCIIgiDYoQCAIgiAIog0KEAiCIAiCaIMChALf//73oaOjA8rKymD27Nnw/PPPF7tJJcH69eth5syZUF1dDUOGDIFzzz0Xtm3bRpU5cuQIrFq1ChobG6GqqgouuOAC2LNnT5FaXFrccccdYBgGrF692v4O+1OfnTt3wmWXXQaNjY1QXl4OkyZNghdffNH+3TRN+PKXvwzDhg2D8vJyWLJkCbz55ptFbHG0GRgYgNtuuw06OzuhvLwcuru74etf/zq1twL2qZw//vGPcNZZZ0FraysYhgEPP/ww9btK/+3btw+WL18ONTU1UFdXB1deeSV89NFHwTbURKTcf//9ZiqVMn/84x+br7zyirly5Uqzrq7O3LNnT7GbFnmWLl1qbtiwwdy6dau5ZcsW84wzzjDb29vNjz76yC5z7bXXmm1tbeamTZvMF1980ZwzZ445b968Ira6NHj++efNjo4O88QTTzRvvPFG+3vsTz327dtnjhw50lyxYoX53HPPmW+99Zb56KOPmtu3b7fL3HHHHWZtba358MMPmy+//LJ59tlnm52dnebhw4eL2PLosm7dOrOxsdF85JFHzLffftt84IEHzKqqKvN73/ueXQb7VM7vf/9789ZbbzUffPBBEwDMhx56iPpdpf+WLVtmTp482Xz22WfNP/3pT+aoUaPMSy+9NNB2ogDhwqxZs8xVq1bZnwcGBszW1lZz/fr1RWxVabJ3714TAMynnnrKNE3T3L9/v5lMJs0HHnjALvPaa6+ZAGA+88wzxWpm5Ont7TVHjx5tPv744+bJJ59sCxDYn/p88YtfNBcsWCD8PZPJmEOHDjW//e1v29/t37/fTKfT5s9//vNCNLHkOPPMM83PfOYz1Hfnn3++uXz5ctM0sU91YQUIlf579dVXTQAwX3jhBbvMxo0bTcMwzJ07dwbWNjRhSDh27Bhs3rwZlixZYn8Xi8VgyZIl8MwzzxSxZaXJgQMHAACgoaEBAAA2b94MfX19VP+OGzcO2tvbsX8lrFq1Cs4880yq3wCwP73wm9/8BmbMmAEXXnghDBkyBKZOnQr/9m//Zv/+9ttvw+7du6k+ra2thdmzZ2OfCpg3bx5s2rQJ3njjDQAAePnll+Hpp5+G008/HQCwT/2i0n/PPPMM1NXVwYwZM+wyS5YsgVgsBs8991xgbRl0m2kFyQcffAADAwPQ0tJCfd/S0gKvv/56kVpVmmQyGVi9ejXMnz8fJk6cCAAAu3fvhlQqBXV1dVTZlpYW2L17dxFaGX3uv/9++POf/wwvvPCC4zfsT33eeustuPvuu+Gmm26CL33pS/DCCy/ADTfcAKlUCi6//HK733hzAPYpnzVr1kBPTw+MGzcO4vE4DAwMwLp162D58uUAANinPlHpv927d8OQIUOo3xOJBDQ0NATaxyhAIAVh1apVsHXrVnj66aeL3ZSSZceOHXDjjTfC448/DmVlZcVuzqAgk8nAjBkz4Jvf/CYAAEydOhW2bt0KP/jBD+Dyyy8vcutKk1/+8pdw7733wn333QcTJkyALVu2wOrVq6G1tRX7dJCBJgwJTU1NEI/HHV7se/bsgaFDhxapVaXH9ddfD4888gg8+eST1FbrQ4cOhWPHjsH+/fup8ti/fDZv3gx79+6FadOmQSKRgEQiAU899RTceeedkEgkoKWlBftTk2HDhsH48eOp70444QR47733AADsfsM5QJ1bbrkF1qxZA5dccglMmjQJPvWpT8HnP/95WL9+PQBgn/pFpf+GDh0Ke/fupX7v7++Hffv2BdrHKEBISKVSMH36dNi0aZP9XSaTgU2bNsHcuXOL2LLSwDRNuP766+Ghhx6CJ554Ajo7O6nfp0+fDslkkurfbdu2wXvvvYf9y2Hx4sXw17/+FbZs2WL/mzFjBixfvtz+G/tTj/nz5ztCi9944w0YOXIkAAB0dnbC0KFDqT7t6emB5557DvtUwKFDhyAWo5eWeDwOmUwGALBP/aLSf3PnzoX9+/fD5s2b7TJPPPEEZDIZmD17dnCNCcwdc5By//33m+l02vzJT35ivvrqq+bVV19t1tXVmbt37y520yLPddddZ9bW1pp/+MMfzPfff9/+d+jQIbvMtddea7a3t5tPPPGE+eKLL5pz5841586dW8RWlxZkFIZpYn/q8vzzz5uJRMJct26d+eabb5r33nuvWVFRYd5zzz12mTvuuMOsq6szf/3rX5t/+ctfzHPOOQdDDiVcfvnl5vDhw+0wzgcffNBsamoyv/CFL9hlsE/l9Pb2mi+99JL50ksvmQBgfve73zVfeukl89133zVNU63/li1bZk6dOtV87rnnzKefftocPXo0hnEWg7vuustsb283U6mUOWvWLPPZZ58tdpNKAgDg/tuwYYNd5vDhw+ZnP/tZs76+3qyoqDDPO+888/333y9eo0sMVoDA/tTnt7/9rTlx4kQznU6b48aNM3/0ox9Rv2cyGfO2224zW1pazHQ6bS5evNjctm1bkVobfXp6eswbb7zRbG9vN8vKysyuri7z1ltvNY8ePWqXwT6V8+STT3Lnzssvv9w0TbX++/DDD81LL73UrKqqMmtqaswrrrjC7O3tDbSduJ03giAIgiDaoA8EgiAIgiDaoACBIAiCIIg2KEAgCIIgCKINChAIgiAIgmiDAgSCIAiCINqgAIEgCIIgiDYoQCAIgiAIog0KEAiCIAiCaIMCBIIgCIIg2qAAgSAIgiCINihAIAiCIAiiDQoQCIIgCIJo8z/PGZ4sXj89TAAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def f_unstable(t, y, args):\n", + " return -(y**3)\n", + "\n", + "\n", + "def g_unstable(t, y, args):\n", + " return jnp.array(10.0, dtype=y.dtype)\n", + "\n", + "\n", + "y0_unstable = jnp.array(0.0)\n", + "\n", + "t1_unstable = 100.0\n", + "\n", + "key_unstable = jr.key(0)\n", + "bm_1d = diffrax.VirtualBrownianTree(\n", + " t0, t1_unstable, 2**-10, (), key_unstable, levy_area=SpaceTimeLevyArea\n", + ")\n", + "\n", + "terms = diffrax.MultiTerm(\n", + " diffrax.ODETerm(f_unstable),\n", + " diffrax.ControlTerm(g_unstable, bm_1d),\n", + ")\n", + "\n", + "dt_unstable = 0.05\n", + "controller = diffrax.PIDController(\n", + " rtol=0,\n", + " atol=10.0,\n", + " pcoeff=0.2,\n", + " icoeff=0.5,\n", + " dcoeff=0,\n", + " dtmin=2**-10,\n", + " dtmax=0.5,\n", + ")\n", + "\n", + "sol_const = diffrax.diffeqsolve(\n", + " terms,\n", + " SPaRK(),\n", + " t0,\n", + " t1_unstable,\n", + " dt_unstable,\n", + " y0_unstable,\n", + " (),\n", + " saveat=diffrax.SaveAt(steps=True),\n", + " max_steps=2**18,\n", + ")\n", + "\n", + "sol_adap = diffrax.diffeqsolve(\n", + " terms,\n", + " SPaRK(),\n", + " t0,\n", + " t1_unstable,\n", + " dt_unstable,\n", + " y0_unstable,\n", + " (),\n", + " saveat=diffrax.SaveAt(steps=True),\n", + " stepsize_controller=controller,\n", + " max_steps=2**18,\n", + ")\n", + "\n", + "print(\n", + " f\"Num steps: constant: {sol_const.stats['num_steps']},\"\n", + " f\" adaptive: {sol_adap.stats['num_steps']}\"\n", + ")\n", + "\n", + "# Plot each solution in a separate figure\n", + "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))\n", + "ax1.plot(sol_const.ts, sol_const.ys, label=\"Constant\")\n", + "ax1.set_title(\"Constant step size\")\n", + "ax1.legend()\n", + "\n", + "ax2.plot(sol_adap.ts, sol_adap.ys, label=\"Adaptive\")\n", + "ax2.set_title(\"Adaptive step size\")\n", + "ax2.legend()\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c713ae477ea7b36", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}