From deb88b49d7c10385216d55485306a0541225899e Mon Sep 17 00:00:00 2001 From: Kade Heckel Date: Sun, 4 Feb 2024 10:26:49 +0000 Subject: [PATCH] changing spyx.data to optional --- README.md | 10 ++++++---- setup.py | 12 +++++++++++- spyx/experimental.py | 2 ++ spyx/nn.py | 2 +- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 754978a..7a72eb7 100644 --- a/README.md +++ b/README.md @@ -10,9 +10,9 @@ Why use Spyx? ============= -Other frameworks such as SNNTorch and Norse offer a nice range of features such as training with adjoint gradients or support for IPUs in addition to their wonderful tutorials. Spyx is designed to maximize performance by achieving maximum GPU utilization, allowing the training of networks for hundreds of epochs at incredible speed. +Spyx is a compact spiking neural network library built on top of DeepMind's Haiku package, offering the flexibility and extensibility of PyTorch-based frameworks while enabling the extreme perfomance of SNN libraries which implement custom CUDA kernels for their dynamics. -Spyx is a compact spiking neural network library built on top of DeepMind's Haiku package, offering the flexibility and extensibility of PyTorch based frameworks while enabling the extreme perfomance of SNN libraries which implement custom CUDA kernels for their dynamics. The library currently supports training SNNs via surrogate gradient descent and neuroevolution, with additional capabilities such as ANN2SNN conversion, Phasor Networks, and EXODUS being planned for the future. Spyx offers a number of predefined neuron models but is designed for it to be easy to define your own and plug it into a model; the hope is to soon include definitions of SpikingRWKV and other more sophisticated model blocks into the framework. +The library currently supports training SNNs via surrogate gradient descent and neuroevolution, with additional capabilities such as ANN2SNN conversion and Phasor Networks being planned for the future. Spyx offers a number of predefined neuron models but is designed for it to be easy to define your own and plug it into a model; the hope is to soon include definitions of SpikingRWKV and other more sophisticated model blocks into the framework. Installation: ============= @@ -21,18 +21,20 @@ As with other libraries built on top of JAX, you need to install jax with GPU if The best way to install and run Spyx is if you install it into a container/environment that already has JAX and PyTorch installed. +The spyx.data submodule contains some pre-built dataloaders for use with spyx - to install the depedencies for it run the command `pip install spyx[data]` + Hardware Requirements: ====================== Spyx achieves extremely high performance by maintaining the entire dataset in the GPU's vRAM; as such a decent amount of memory for both the CPU and GPU are needed to handle the dataset loading and then training. For smaller networks of only several hundred thousand parameters, the training process can be comfortably executed on even laptop GPU's with only 6GB of vRAM. For large SNNs or for neuroevolution it is recommended to use a higher memory card. -Since Spyx is developed on top of the current JAX version, it does not work on Google Colab's TPUs which use an older version. Cloud TPU support will be tested in the near future. Support for GraphCore's IPU's could be possible based on their fork of JAX but has not been explored. +Since Spyx is developed on top of the current JAX version, it does not work on Google Colab's TPUs which use an older version. Cloud TPU support will be tested in the near future. Contributing: ============= -If you'd like to contribute, head on over to the issues page to find proposed enhancements and leave a comment! +If you'd like to contribute, head on over to the issues page to find proposed enhancements and leave a comment! Also head over to the Open Neuromorphic Discord server to ask questions! Citation: ========= diff --git a/setup.py b/setup.py index 824ffa5..3d75e0e 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,15 @@ ], ) +extras = { + 'data' : [ + 'tonic', + 'torch', + 'torchvision', + 'sklearn' + ] +} + # This call to setup() does all the work setup( name="spyx", @@ -44,5 +53,6 @@ ], packages=["spyx"], include_package_data=True, - install_requires=requires + install_requires=requires, + extras_require=extras ) diff --git a/spyx/experimental.py b/spyx/experimental.py index cb87bb1..f9cf799 100644 --- a/spyx/experimental.py +++ b/spyx/experimental.py @@ -48,6 +48,8 @@ def __init__(self, hidden_shape: tuple, threshold=1, k=10, """ :hidden_shape: Size of hidden layer for the number of Voltages to track. + :threshold: Value for which probability of firing exceeds 50% + :k: The slope of the sigmoid function, the higher the value the closer membrane voltage must to the threshold to have a chance of firing but also a higher chance of continuous firing """ super().__init__(name=name) diff --git a/spyx/nn.py b/spyx/nn.py index cda9337..bf1c26d 100644 --- a/spyx/nn.py +++ b/spyx/nn.py @@ -381,7 +381,7 @@ def __call__(self, x, V): # calculate whether spike is generated, and update membrane potential spikes = self.act(V - self.threshold) - feedback = spikes@recurrent + bias + feedback = spikes@recurrent + bias # this seems like an error... V = V + x + spikes@recurrent - spikes*self.threshold return spikes, V