Merge pull request #481 from ztqakita/master

[dyn] add STDP_Song2000 LTP model
chaoming0625 authored Sep 11, 2023
2 parents 7b1faf2 + 36f4585 commit 4585f20
Showing 9 changed files with 460 additions and 26 deletions.
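For context on what the PR adds: the Song, Miller & Abbott (2000) pair-based STDP rule keeps one exponentially decaying trace on each side of the synapse and nudges the weight whenever either side spikes. Sketched with the parameter names the new class uses (tau_s, tau_t, A1, A2; pairing the names to the pre/post sides this way is my reading of the diff, not stated on this page):

    tau_s * dA_pre/dt  = -A_pre,    A_pre  -> A_pre  + A1  at each presynaptic spike
    tau_t * dA_post/dt = -A_post,   A_post -> A_post + A2  at each postsynaptic spike

    on a presynaptic spike:   W -> W - A_post   (depression, LTD)
    on a postsynaptic spike:  W -> W + A_pre    (potentiation, LTP)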
4 changes: 2 additions & 2 deletions brainpy/_src/delay.py
@@ -327,7 +327,7 @@ def retrieve(self, delay_step, *indices):

     if self.method == ROTATE_UPDATE:
       i = share.load('i')
-      delay_idx = bm.as_jax((delay_step - i - 1) % self.max_length)
+      delay_idx = bm.as_jax((delay_step - i - 1) % self.max_length, dtype=jnp.int32)
       delay_idx = jax.lax.stop_gradient(delay_idx)

     elif self.method == CONCAT_UPDATE:
@@ -358,7 +358,7 @@ def update(
     # update the delay data at the rotation index
     if self.method == ROTATE_UPDATE:
       i = share.load('i')
-      idx = bm.as_jax((-i - 1) % self.max_length)
+      idx = bm.as_jax((-i - 1) % self.max_length, dtype=jnp.int32)
       self.data[idx] = latest_value

     # update the delay data at the first position
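Why the rotation indexing works: at global step i the buffer writes the newest value at (-i - 1) % max_length, so the value stored delay_step steps earlier sits at (delay_step - i - 1) % max_length, provided delay_step < max_length. A minimal NumPy sketch of that invariant (buffer size and values are illustrative; the dtype=jnp.int32 added above pins the index dtype, presumably to keep it stable under JAX's 64-bit mode):

    import numpy as np

    max_length = 4                      # illustrative buffer size
    buf = np.zeros(max_length)

    for i in range(20):                 # i plays the role of share.load('i')
      buf[(-i - 1) % max_length] = i    # write the newest value at the rotation index
      delay_step = 3                    # read back 3 steps; requires delay_step < max_length
      if i >= delay_step:
        got = buf[(delay_step - i - 1) % max_length]
        assert got == i - delay_step    # retrieves exactly the value from delay_step steps ago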
115 changes: 103 additions & 12 deletions brainpy/_src/dnn/linear.py
@@ -15,7 +15,7 @@
 from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
 from brainpy.types import ArrayType, Sharding
 from brainpy._src.dnn.base import Layer
-from brainpy._src.mixin import SupportOnline, SupportOffline
+from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP

 __all__ = [
   'Dense', 'Linear',
@@ -29,22 +29,22 @@
 ]


-class Dense(Layer, SupportOnline, SupportOffline):
+class Dense(Layer, SupportOnline, SupportOffline, SupportSTDP):
   r"""A linear transformation applied over the last dimension of the input.

   Mathematically, this node can be defined as:

   .. math::

-     y = x \cdot W + b
+     y = x \cdot weight + b

   Parameters
   ----------
   num_in: int
     The number of the input feature. A positive integer.
   num_out: int
     The number of the output features. A positive integer.
-  W_initializer: optional, Initializer
+  weight_initializer: optional, Initializer
     The weight initialization.
   b_initializer: optional, Initializer
     The bias initialization.
@@ -74,13 +74,13 @@ def __init__(
                        f'a positive integer. Received: num_out={num_out}')

     # weight initializer
-    self.weight_initializer = W_initializer
+    self.W_initializer = W_initializer
     self.bias_initializer = b_initializer
     is_initializer(W_initializer, 'weight_initializer')
     is_initializer(b_initializer, 'bias_initializer', allow_none=True)

     # parameter initialization
-    W = parameter(self.weight_initializer, (num_in, self.num_out))
+    W = parameter(self.W_initializer, (num_in, self.num_out))
     b = parameter(self.bias_initializer, (self.num_out,))
     if isinstance(self.mode, bm.TrainingMode):
       W = bm.TrainVar(W)
@@ -198,6 +198,20 @@ def offline_fit(self,
     self.W.value = Wff
     self.b.value = bias[0]

+  def update_STDP(self, dW, constraints=None):
+    if isinstance(self.W, float):
+      raise ValueError('Cannot update the weight of a constant node.')
+    if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+      raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+    if self.W.shape != dW.shape:
+      raise ValueError(f'The shape of delta_weight {dW.shape} '
+                       f'should be the same as the shape of weight {self.W.shape}.')
+    if not isinstance(self.W, bm.Variable):
+      self.tracing_variable('W', self.W, self.W.shape)
+    self.W += dW
+    if constraints is not None:
+      self.W.value = constraints(self.W)
+

 Linear = Dense
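A hedged usage sketch of the new SupportSTDP hook on Dense: an STDP rule computes a dense delta with the same shape as the weight and may pass a constraint that is applied after the update (the sizes, the random delta, and the clipping bound below are made up for illustration):

    import brainpy as bp
    import brainpy.math as bm

    layer = bp.dnn.Dense(10, 5, mode=bm.training_mode)   # illustrative layer
    dW = bm.random.normal(0., 1e-3, (10, 5))             # stand-in for an STDP-derived delta

    # add the delta, then project the weights back into [0, 0.1]
    layer.update_STDP(dW, constraints=lambda w: bm.clip(w, 0., 0.1))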

@@ -213,7 +227,7 @@ def update(self, x):
     return x


-class AllToAll(Layer):
+class AllToAll(Layer, SupportSTDP):
   """Synaptic matrix multiplication with All2All connections.

   Args:
@@ -275,8 +289,23 @@ def update(self, pre_val):
     post_val = pre_val @ self.weight
     return post_val

+  def update_STDP(self, dW, constraints=None):
+    if isinstance(self.weight, float):
+      raise ValueError('Cannot update the weight of a constant node.')
+    if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+      raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+    if self.weight.shape != dW.shape:
+      raise ValueError(f'The shape of delta_weight {dW.shape} '
+                       f'should be the same as the shape of weight {self.weight.shape}.')
+    if not isinstance(self.weight, bm.Variable):
+      self.tracing_variable('weight', self.weight, self.weight.shape)
+    self.weight += dW
+    if constraints is not None:
+      self.weight.value = constraints(self.weight)
+

-class OneToOne(Layer):
+class OneToOne(Layer, SupportSTDP):
   """Synaptic matrix multiplication with One2One connection.

   Args:
@@ -309,8 +338,23 @@ def __init__(
   def update(self, pre_val):
     return pre_val * self.weight

+  def update_STDP(self, dW, constraints=None):
+    if isinstance(self.weight, float):
+      raise ValueError('Cannot update the weight of a constant node.')
+    if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+      raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+    dW = dW.sum(axis=0)
+    if self.weight.shape != dW.shape:
+      raise ValueError(f'The shape of delta_weight {dW.shape} '
+                       f'should be the same as the shape of weight {self.weight.shape}.')
+    if not isinstance(self.weight, bm.Variable):
+      self.tracing_variable('weight', self.weight, self.weight.shape)
+    self.weight += dW
+    if constraints is not None:
+      self.weight.value = constraints(self.weight)
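Note the extra dW.sum(axis=0) relative to the Dense version: OneToOne stores a weight vector with one entry per unit, so a dense (num, num) delta must be collapsed to one value per postsynaptic unit. For strictly one-to-one wiring only the diagonal of dW is non-zero, so the column sum recovers exactly the per-synapse deltas. A toy check (shapes illustrative):

    import numpy as np

    dW = np.diag([0.1, 0.2, 0.3, 0.4])               # one-to-one: off-diagonal deltas are zero
    assert np.allclose(dW.sum(axis=0), np.diag(dW))  # column sums equal the diagonal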


-class MaskedLinear(Layer):
+class MaskedLinear(Layer, SupportSTDP):
   r"""Synaptic matrix multiplication with masked dense computation.

   It performs the computation of:
@@ -363,8 +407,23 @@ def __init__(
   def update(self, x):
     return x @ self.mask_fun(self.weight * self.mask)

+  def update_STDP(self, dW, constraints=None):
+    if isinstance(self.weight, float):
+      raise ValueError('Cannot update the weight of a constant node.')
+    if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+      raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+    if self.weight.shape != dW.shape:
+      raise ValueError(f'The shape of delta_weight {dW.shape} '
+                       f'should be the same as the shape of weight {self.weight.shape}.')
+    if not isinstance(self.weight, bm.Variable):
+      self.tracing_variable('weight', self.weight, self.weight.shape)
+    self.weight += dW
+    if constraints is not None:
+      self.weight.value = constraints(self.weight)


-class CSRLinear(Layer):
+class CSRLinear(Layer, SupportSTDP):
   r"""Synaptic matrix multiplication with CSR sparse computation.

   It performs the computation of:
@@ -432,6 +491,22 @@ def _batch_csrmv(self, x):
                          transpose=self.transpose,
                          method=self.method)

+  def update_STDP(self, dW, constraints=None):
+    if isinstance(self.weight, float):
+      raise ValueError('Cannot update the weight of a constant node.')
+    if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+      raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+    pre_ids, post_ids = bm.sparse.csr_to_coo(self.indices, self.indptr)
+    sparse_dW = dW[pre_ids, post_ids]
+    if self.weight.shape != sparse_dW.shape:
+      raise ValueError(f'The shape of sparse delta_weight {sparse_dW.shape} '
+                       f'should be the same as the shape of sparse weight {self.weight.shape}.')
+    if not isinstance(self.weight, bm.Variable):
+      self.tracing_variable('weight', self.weight, self.weight.shape)
+    self.weight += sparse_dW
+    if constraints is not None:
+      self.weight.value = constraints(self.weight)
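For the sparse layers the dense delta is first gathered down to the stored non-zeros: bm.sparse.csr_to_coo expands the CSR indptr/indices pair into per-connection (pre, post) coordinates, and dW[pre_ids, post_ids] then selects exactly one delta per stored weight. The same gather in plain NumPy (the CSR arrays are hand-built for illustration):

    import numpy as np

    # CSR for 3 pre -> 3 post: row 0 connects to {0, 2}, row 1 to {1}, row 2 to {0, 1}
    indptr  = np.array([0, 2, 3, 5])
    indices = np.array([0, 2, 1, 0, 1])

    # expand CSR to COO: repeat each row id once per stored entry in that row
    pre_ids  = np.repeat(np.arange(3), np.diff(indptr))  # [0, 0, 1, 2, 2]
    post_ids = indices                                   # [0, 2, 1, 0, 1]

    dW = np.arange(9, dtype=float).reshape(3, 3)         # dense (pre, post) delta
    sparse_dW = dW[pre_ids, post_ids]                    # one delta per connection
    assert sparse_dW.tolist() == [0., 2., 4., 6., 7.]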


 class CSCLinear(Layer):
   r"""Synaptic matrix multiplication with CSC sparse computation.
@@ -468,7 +543,7 @@ def __init__(
     self.sharding = sharding


-class EventCSRLinear(Layer):
+class EventCSRLinear(Layer, SupportSTDP):
   r"""Synaptic matrix multiplication with event CSR sparse computation.

   It performs the computation of:
@@ -532,6 +607,22 @@ def _batch_csrmv(self, x):
                                shape=(self.conn.pre_num, self.conn.post_num),
                                transpose=self.transpose)

+  def update_STDP(self, dW, constraints=None):
+    if isinstance(self.weight, float):
+      raise ValueError('Cannot update the weight of a constant node.')
+    if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+      raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+    pre_ids, post_ids = bm.sparse.csr_to_coo(self.indices, self.indptr)
+    sparse_dW = dW[pre_ids, post_ids]
+    if self.weight.shape != sparse_dW.shape:
+      raise ValueError(f'The shape of sparse delta_weight {sparse_dW.shape} '
+                       f'should be the same as the shape of sparse weight {self.weight.shape}.')
+    if not isinstance(self.weight, bm.Variable):
+      self.tracing_variable('weight', self.weight, self.weight.shape)
+    self.weight += sparse_dW
+    if constraints is not None:
+      self.weight.value = constraints(self.weight)


 class BcsrMM(Layer):
   r"""Synaptic matrix multiplication with BCSR sparse computation.
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/projections/aligns.py
@@ -1052,4 +1052,4 @@ def update(self):
     spk = self.refs['delay'].at(self.name)
     g = self.comm(self.syn(spk))
     self.refs['out'].bind_cond(g)
-    return g
\ No newline at end of file
+    return g
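Putting the pieces together: the new STDP_Song2000 projection (the bulk of the 460 added lines, in files not expanded on this page) wires a presynaptic population, a SupportSTDP communication layer such as EventCSRLinear above, a synapse, an output, and a postsynaptic population, and presumably applies trace-based weight updates through the update_STDP hook added above. A sketch of how it is wired; the keyword names and defaults (pre, delay, syn, comm, out, post, tau_s=16.8, tau_t=33.7, A1=0.96, A2=0.53) follow BrainPy's documentation for this class rather than anything shown in the hunks above:

    import brainpy as bp

    num_pre, num_post = 20, 10                    # illustrative population sizes
    pre = bp.dyn.LifRef(num_pre)
    post = bp.dyn.LifRef(num_post)

    proj = bp.dyn.STDP_Song2000(
      pre=pre,
      delay=1.,
      comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.5, pre=num_pre, post=num_post),
                                 weight=bp.init.Uniform(0., 0.1)),
      syn=bp.dyn.Expon.desc(num_post, tau=5.),    # exponential synapse, post-aligned
      out=bp.dyn.COBA.desc(E=0.),                 # conductance-based output
      post=post,
      tau_s=16.8, tau_t=33.7, A1=0.96, A2=0.53,
    )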