From fa67f219bc970cf1e9491df6e46ab3d03968c8f8 Mon Sep 17 00:00:00 2001 From: Kade Heckel Date: Wed, 10 Jan 2024 16:14:19 +0000 Subject: [PATCH] working on spyx v snntorch speed comparison --- research/benchmarking/SHD_comparison.ipynb | 521 +++++++++++++++++++++ 1 file changed, 521 insertions(+) create mode 100644 research/benchmarking/SHD_comparison.ipynb diff --git a/research/benchmarking/SHD_comparison.ipynb b/research/benchmarking/SHD_comparison.ipynb new file mode 100644 index 0000000..57eab91 --- /dev/null +++ b/research/benchmarking/SHD_comparison.ipynb @@ -0,0 +1,521 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "6793a986", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \".40\"\n", + "\n", + "import numpy as np\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import spyx\n", + "import haiku as hk\n", + "import optax\n", + "from jax_tqdm import scan_tqdm\n", + "\n", + "import torch\n", + "import snntorch" + ] + }, + { + "cell_type": "markdown", + "id": "da30d5e6", + "metadata": {}, + "source": [ + "## SHD Dataloading" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ab70fd4b", + "metadata": {}, + "outputs": [], + "source": [ + "shd_dl = spyx.data.SHD_loader(256,128,128)\n", + "\n", + "key = jax.random.PRNGKey(0)\n", + "x, y = shd_dl.train_epoch(key)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "89d88f02", + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"mps\") if torch.backends.mps.is_available() else torch.device(\"cpu\")\n", + "j2t_data = lambda data: torch.from_numpy(np.array(jnp.unpackbits(data, axis=1))).to(device)\n", + "j2t_targets = lambda tgt: torch.from_numpy(np.array(tgt)).to(device)" + ] + }, + { + "cell_type": "markdown", + "id": "1f415eb8", + "metadata": {}, + "source": [ + "## Spyx SHD" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ddc797e3", + "metadata": {}, + "outputs": [], + "source": [ + "def shd_snn(x):\n", + " \n", + " core = hk.DeepRNN([\n", + " hk.Linear(64, with_bias=False),\n", + " spyx.nn.LIF((64,), activation=spyx.axn.Axon(spyx.axn.triangular())),\n", + " hk.Linear(64, with_bias=False),\n", + " spyx.nn.LIF((64,), activation=spyx.axn.Axon(spyx.axn.triangular())),\n", + " hk.Linear(20, with_bias=False),\n", + " spyx.nn.LI((20,))\n", + " ])\n", + " \n", + " # static unroll for maximum performance\n", + " spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=32)\n", + " \n", + " return spikes, V" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "77b65a1d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "jaxlib.xla_extension.ArrayImpl" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "344ed98f", + "metadata": {}, + "outputs": [], + "source": [ + "key = jax.random.PRNGKey(0)\n", + "# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!\n", + "SNN = hk.without_apply_rng(hk.transform(shd_snn))\n", + "params = SNN.init(rng=key, x=jnp.float16(x[0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9f7422a0", + "metadata": {}, + "outputs": [], + "source": [ + "def gd(SNN, params, dl, epochs=300):\n", + " \n", + " opt = optax.adam(learning_rate=5e-4)\n", + " \n", + " # create and initialize the optimizer\n", + " opt_state = opt.init(params)\n", + " grad_params = params\n", + " \n", + " # define and compile our eval function that computes the loss for our SNN\n", + " @jax.jit\n", + " def net_eval(weights, events, targets):\n", + " readout = SNN.apply(weights, events)\n", + " traces, V_f = readout\n", + " return spyx.fn.integral_crossentropy(traces, targets)\n", + " \n", + " # Use JAX to create a function that calculates the loss and the gradient!\n", + " surrogate_grad = jax.value_and_grad(net_eval) \n", + " \n", + " rng = jax.random.PRNGKey(0) \n", + " \n", + " # compile the meat of our training loop for speed\n", + " @jax.jit\n", + " def train_step(state, data):\n", + " grad_params, opt_state = state\n", + " events, targets = data # fix this\n", + " events = jnp.unpackbits(events, axis=1) # decompress temporal axis\n", + " # compute loss and gradient # need better augment rng\n", + " loss, grads = surrogate_grad(grad_params, events, targets)\n", + " # generate updates based on the gradients and optimizer\n", + " updates, opt_state = opt.update(grads, opt_state, grad_params)\n", + " # return the updated parameters\n", + " new_state = [optax.apply_updates(grad_params, updates), opt_state]\n", + " return new_state, loss\n", + " \n", + " # For validation epochs, do the same as before but compute the\n", + " # accuracy, predictions and losses (no gradients needed)\n", + " @jax.jit\n", + " def eval_step(grad_params, data):\n", + " events, targets = data # fix\n", + " events = jnp.unpackbits(events, axis=1)\n", + " readout = SNN.apply(grad_params, events)\n", + " traces, V_f = readout\n", + " acc, pred = spyx.fn.integral_accuracy(traces, targets)\n", + " loss = spyx.fn.integral_crossentropy(traces, targets)\n", + " return grad_params, jnp.array([acc, loss])\n", + " \n", + " \n", + " val_data = dl.val_epoch()\n", + " \n", + " # Here's the start of our training loop!\n", + " @scan_tqdm(epochs)\n", + " def epoch(epoch_state, epoch_num):\n", + " curr_params, curr_opt_state = epoch_state\n", + " \n", + " shuffle_rng = jax.random.fold_in(rng, epoch_num)\n", + " train_data = dl.train_epoch(shuffle_rng)\n", + " \n", + " # train epoch\n", + " end_state, train_loss = jax.lax.scan(\n", + " train_step,# func\n", + " [curr_params, curr_opt_state],# init\n", + " train_data,# xs\n", + " train_data.obs.shape[0]# len\n", + " )\n", + " \n", + " new_params, _ = end_state\n", + " \n", + " # val epoch\n", + " _, val_metrics = jax.lax.scan(\n", + " eval_step,# func\n", + " new_params,# init\n", + " val_data,# xs\n", + " val_data.obs.shape[0]# len\n", + " )\n", + "\n", + " \n", + " return end_state, jnp.concatenate([jnp.expand_dims(jnp.mean(train_loss),0), jnp.mean(val_metrics, axis=0)])\n", + " # end epoch\n", + " \n", + " # epoch loop\n", + " final_state, metrics = jax.lax.scan(\n", + " epoch,\n", + " [grad_params, opt_state], # metric arrays\n", + " jnp.arange(epochs), # \n", + " epochs # len of loop\n", + " )\n", + " \n", + " final_params, _ = final_state\n", + " \n", + " \n", + " # return our final, optimized network. \n", + " return final_params, metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "5a5cb1c4", + "metadata": {}, + "outputs": [], + "source": [ + "def test_gd(SNN, params, dl):\n", + "\n", + " @jax.jit\n", + " def test_step(params, data):\n", + " events, targets = data\n", + " events = jnp.unpackbits(events, axis=1)\n", + " readout = SNN.apply(params, events)\n", + " traces, V_f = readout\n", + " acc, pred = spyx.fn.integral_accuracy(traces, targets)\n", + " loss = spyx.fn.integral_crossentropy(traces, targets)\n", + " return params, [acc, loss, pred, targets]\n", + " \n", + " test_data = dl.test_epoch()\n", + " \n", + " _, test_metrics = jax.lax.scan(\n", + " test_step,# func\n", + " params,# init\n", + " test_data,# xs\n", + " test_data.obs.shape[0]# len\n", + " )\n", + " \n", + " acc = jnp.mean(test_metrics[0])\n", + " loss = jnp.mean(test_metrics[1])\n", + " preds = jnp.array(test_metrics[2]).flatten()\n", + " tgts = jnp.array(test_metrics[3]).flatten()\n", + " return acc, loss, preds, tgts" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "dd08a737", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4bb910e2154b443c9b38409fef6d79e6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/300 [00:00> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_21888/2264808485.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# forward pass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0mspk_rec\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmem_rec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;31m# initialize the loss & sum over time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1499\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1502\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/tmp/ipykernel_21888/1276384039.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0mspk2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmem2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlif2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcur2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmem2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0mcur3\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mspk2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mspk3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmem3\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlif2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcur3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmem3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0mV\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmem3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1499\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1502\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/snntorch/_neurons/leaky.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_, mem)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 179\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit_hidden\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 180\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmem_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmem\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 181\u001b[0m \u001b[0mmem\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_build_state_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmem\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 182\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/snntorch/_neurons/neurons.py\u001b[0m in \u001b[0;36mmem_reset\u001b[0;34m(self, mem)\u001b[0m\n\u001b[1;32m 105\u001b[0m Returns reset.\"\"\"\n\u001b[1;32m 106\u001b[0m \u001b[0mmem_shift\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmem\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mthreshold\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 107\u001b[0;31m \u001b[0mreset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspike_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmem_shift\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclone\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mreset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 5.80 GiB total capacity; 2.31 GiB already allocated; 367.69 MiB free; 2.90 GiB allowed; 2.89 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF" + ] + } + ], + "source": [ + "num_epochs = 10\n", + "loss_hist = []\n", + "test_loss_hist = []\n", + "counter = 0\n", + "batch_size=256\n", + "num_steps=128\n", + "\n", + "rng = jax.random.PRNGKey(0)\n", + "\n", + "# Outer training loop\n", + "for epoch in range(num_epochs):\n", + " iter_counter = 0\n", + " \n", + " shuffle_rng = jax.random.fold_in(rng, iter_counter)\n", + " train_batch = shd_dl.train_epoch(shuffle_rng)\n", + " train_data, targets = train_batch\n", + " \n", + " \n", + " # Minibatch training loop\n", + " for data, targets in zip(train_data, targets):\n", + "\n", + " data = j2t_data(data).to(dtype=torch.float32)\n", + " targets = j2t_targets(targets)\n", + " # forward pass\n", + " net.train()\n", + " spk_rec, mem_rec = net(data)\n", + "\n", + " # initialize the loss & sum over time\n", + " loss_val = torch.zeros((1), dtype=dtype, device=device)\n", + " for step in range(num_steps):\n", + " loss_val += loss(mem_rec[step], targets)\n", + "\n", + " # Gradient calculation + weight update\n", + " optimizer.zero_grad()\n", + " loss_val.backward()\n", + " optimizer.step()\n", + "\n", + " # Store loss history for future plotting\n", + " loss_hist.append(loss_val.item())\n", + "\n", + " ## Test set\n", + " #with torch.no_grad():\n", + " # net.eval()\n", + " # test_data, test_targets = next(iter(test_loader))\n", + " # test_data = test_data.to(device)\n", + " # test_targets = test_targets.to(device)\n", + "\n", + " # # Test set forward pass\n", + " # test_spk, test_mem = net(test_data.view(batch_size, -1))\n", + "\n", + " # # Test set loss\n", + " # test_loss = torch.zeros((1), dtype=dtype, device=device)\n", + " # for step in range(num_steps):\n", + " # test_loss += loss(test_mem[step], test_targets)\n", + " # test_loss_hist.append(test_loss.item())\n", + "\n", + " # # Print train/test loss/accuracy\n", + " # if counter % 50 == 0:\n", + " # train_printer(\n", + " # data, targets, epoch,\n", + " # counter, iter_counter,\n", + " # loss_hist, test_loss_hist,\n", + " # test_data, test_targets)\n", + " counter += 1\n", + " iter_counter +=1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0bda8291", + "metadata": {}, + "outputs": [], + "source": [ + "train_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ff94871", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}