Skip to content

Commit

Permalink
changing spyx.data to optional
Browse files Browse the repository at this point in the history
  • Loading branch information
kmheckel committed Feb 4, 2024
1 parent 0d1c613 commit deb88b4
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 6 deletions.
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
Why use Spyx?
=============

Other frameworks such as SNNTorch and Norse offer a nice range of features such as training with adjoint gradients or support for IPUs in addition to their wonderful tutorials. Spyx is designed to maximize performance by achieving maximum GPU utilization, allowing the training of networks for hundreds of epochs at incredible speed.
Spyx is a compact spiking neural network library built on top of DeepMind's Haiku package, offering the flexibility and extensibility of PyTorch-based frameworks while enabling the extreme perfomance of SNN libraries which implement custom CUDA kernels for their dynamics.

Spyx is a compact spiking neural network library built on top of DeepMind's Haiku package, offering the flexibility and extensibility of PyTorch based frameworks while enabling the extreme perfomance of SNN libraries which implement custom CUDA kernels for their dynamics. The library currently supports training SNNs via surrogate gradient descent and neuroevolution, with additional capabilities such as ANN2SNN conversion, Phasor Networks, and EXODUS being planned for the future. Spyx offers a number of predefined neuron models but is designed for it to be easy to define your own and plug it into a model; the hope is to soon include definitions of SpikingRWKV and other more sophisticated model blocks into the framework.
The library currently supports training SNNs via surrogate gradient descent and neuroevolution, with additional capabilities such as ANN2SNN conversion and Phasor Networks being planned for the future. Spyx offers a number of predefined neuron models but is designed for it to be easy to define your own and plug it into a model; the hope is to soon include definitions of SpikingRWKV and other more sophisticated model blocks into the framework.

Installation:
=============
Expand All @@ -21,18 +21,20 @@ As with other libraries built on top of JAX, you need to install jax with GPU if

The best way to install and run Spyx is if you install it into a container/environment that already has JAX and PyTorch installed.

The spyx.data submodule contains some pre-built dataloaders for use with spyx - to install the depedencies for it run the command `pip install spyx[data]`

Hardware Requirements:
======================

Spyx achieves extremely high performance by maintaining the entire dataset in the GPU's vRAM; as such a decent amount of memory for both the CPU and GPU are needed to handle the dataset loading and then training. For smaller networks of only several hundred thousand parameters, the training process can be comfortably executed on even laptop GPU's with only 6GB of vRAM. For large SNNs or for neuroevolution it is recommended to use a higher memory card.

Since Spyx is developed on top of the current JAX version, it does not work on Google Colab's TPUs which use an older version. Cloud TPU support will be tested in the near future. Support for GraphCore's IPU's could be possible based on their fork of JAX but has not been explored.
Since Spyx is developed on top of the current JAX version, it does not work on Google Colab's TPUs which use an older version. Cloud TPU support will be tested in the near future.


Contributing:
=============

If you'd like to contribute, head on over to the issues page to find proposed enhancements and leave a comment!
If you'd like to contribute, head on over to the issues page to find proposed enhancements and leave a comment! Also head over to the Open Neuromorphic Discord server to ask questions!

Citation:
=========
Expand Down
12 changes: 11 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@
],
)

extras = {
'data' : [
'tonic',
'torch',
'torchvision',
'sklearn'
]
}

# This call to setup() does all the work
setup(
name="spyx",
Expand All @@ -44,5 +53,6 @@
],
packages=["spyx"],
include_package_data=True,
install_requires=requires
install_requires=requires,
extras_require=extras
)
2 changes: 2 additions & 0 deletions spyx/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def __init__(self, hidden_shape: tuple, threshold=1, k=10,
"""
:hidden_shape: Size of hidden layer for the number of Voltages to track.
:threshold: Value for which probability of firing exceeds 50%
:k: The slope of the sigmoid function, the higher the value the closer membrane voltage must to the threshold to have a chance of firing but also a higher chance of continuous firing
"""
super().__init__(name=name)
Expand Down
2 changes: 1 addition & 1 deletion spyx/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def __call__(self, x, V):

# calculate whether spike is generated, and update membrane potential
spikes = self.act(V - self.threshold)
feedback = spikes@recurrent + bias
feedback = spikes@recurrent + bias # this seems like an error...
V = V + x + spikes@recurrent - spikes*self.threshold

return spikes, V
Expand Down

0 comments on commit deb88b4

Please sign in to comment.