-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
27 changed files
with
1,240 additions
and
674 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Spyx\n", | ||
"\n", | ||
"Spyx is a JAX-based SNN/Deep learning framework that enables fully JIT compiled optimization of models." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import spyx\n", | ||
"import spyx.nn as snn\n", | ||
"\n", | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"\n", | ||
"import nir" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Import a NIR graph to Spyx:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load the NIR graph from disk\n", | ||
"nir_graph = nir.read(\"saved_network.nir\")\n", | ||
"\n", | ||
"# Use the nir_graph and a sample of your input (for shape information)\n", | ||
"# dt is used to scale the weights properly if the imported network was trained\n", | ||
"# in a different simulator where dt is not necessarily 1.\n", | ||
"SNN, params = spyx.nir.from_nir(nir_graph, sample_batch, dt=1)\n", | ||
"\n", | ||
"# Use it as you wish:\n", | ||
"SNN.apply(params, sample_batch)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Export a network from Spyx to a NIR graph:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Some operations may have rearranged the PyTree (dictionary) that stores\n", | ||
"# the SNN weights, so the helper function reorders the dict\n", | ||
"# to allow for proper exportation. \n", | ||
"export_params = spyx.nir.reorder_layers(init_params, optimized_params)\n", | ||
"\n", | ||
"# provide the params to export along with the input/output sizes and the desired\n", | ||
"# time resolution; this is so you can load it up with the proper dt in other\n", | ||
"# frameworks that allow you to specify smaller time intervals\n", | ||
"# whereas Spyx assumes every timestep to be 1 to avoid units.\n", | ||
"nir_graph = spyx.nir.to_nir(export_params, input_shape, output_shape, dt)\n", | ||
"\n", | ||
"# Write the NIR graph to the desired filepath\n", | ||
"nir.write(\"./spyx_shd.nir\", nir_graph)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.