Skip to content

Commit

Permalink
fixed neuron models
Browse files Browse the repository at this point in the history
  • Loading branch information
kmheckel committed Jan 25, 2024
1 parent ced2ae7 commit 00e6ba4
Show file tree
Hide file tree
Showing 27 changed files with 1,240 additions and 674 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
.ipynb_checkpoints
*/.ipynb_checkpoints/*

# datasets
.h5
.zip

# IPython
profile_default/
ipython_config.py
Expand Down
88 changes: 88 additions & 0 deletions docs/examples/nir/conversion.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Spyx\n",
"\n",
"Spyx is a JAX-based SNN/Deep learning framework that enables fully JIT compiled optimization of models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import spyx\n",
"import spyx.nn as snn\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import nir"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import a NIR graph to Spyx:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the NIR graph from disk\n",
"nir_graph = nir.read(\"saved_network.nir\")\n",
"\n",
"# Use the nir_graph and a sample of your input (for shape information)\n",
"# dt is used to scale the weights properly if the imported network was trained\n",
"# in a different simulator where dt is not necessarily 1.\n",
"SNN, params = spyx.nir.from_nir(nir_graph, sample_batch, dt=1)\n",
"\n",
"# Use it as you wish:\n",
"SNN.apply(params, sample_batch)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Export a network from Spyx to a NIR graph:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Some operations may have rearranged the PyTree (dictionary) that stores\n",
"# the SNN weights, so the helper function reorders the dict\n",
"# to allow for proper exportation. \n",
"export_params = spyx.nir.reorder_layers(init_params, optimized_params)\n",
"\n",
"# provide the params to export along with the input/output sizes and the desired\n",
"# time resolution; this is so you can load it up with the proper dt in other\n",
"# frameworks that allow you to specify smaller time intervals\n",
"# whereas Spyx assumes every timestep to be 1 to avoid units.\n",
"nir_graph = spyx.nir.to_nir(export_params, input_shape, output_shape, dt)\n",
"\n",
"# Write the NIR graph to the desired filepath\n",
"nir.write(\"./spyx_shd.nir\", nir_graph)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 00e6ba4

Please sign in to comment.