Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[reset] update logics of state reset in DynamicalSystem #501

Merged
merged 2 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.4.post4"
_minimal_brainpylib_version = '0.1.10'

# fundamental supporting modules
from brainpy import errors, check, tools
Expand All @@ -11,6 +12,15 @@
except ModuleNotFoundError:
raise ModuleNotFoundError(tools.jaxlib_install_info) from None


try:
import brainpylib
if brainpylib.__version__ < _minimal_brainpylib_version:
raise SystemError(f'This version of brainpy ({__version__}) needs brainpylib >= {_minimal_brainpylib_version}.')
del brainpylib
except ModuleNotFoundError:
pass

# Part: Math Foundation #
# ----------------------- #

Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/analysis/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ def update(self):
self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=share['dt']))

def __getattr__(self, item):
child_vars = super(TrajectModel, self).__getattribute__('implicit_vars')
child_vars = super().__getattribute__('implicit_vars')
if item in child_vars:
return child_vars[item]
else:
return super(TrajectModel, self).__getattribute__(item)
return super().__getattribute__(item)

def run(self, duration):
self.runner.run(duration)
Expand Down
11 changes: 5 additions & 6 deletions brainpy/_src/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,6 @@ def reset_state(self, batch_size: int = None):
# initialize delay data
if self.data is not None:
self._init_data(self.max_length, batch_size)
for cls in self.before_updates.values():
cls.reset_state(batch_size)
for cls in self.after_updates.values():
cls.reset_state(batch_size)

def _init_data(self, length: int, batch_size: int = None):
if batch_size is not None:
Expand Down Expand Up @@ -468,13 +464,16 @@ def __init__(
*indices
):
super().__init__(mode=delay.mode)
self.delay = delay
self.refs = {'delay': delay}
assert isinstance(delay, Delay)
delay.register_entry(self.name, time)
self.indices = indices

def update(self):
return self.delay.at(self.name, *self.indices)
return self.refs['delay'].at(self.name, *self.indices)

def reset_state(self, *args, **kwargs):
pass


def init_delay_by_return(info: Union[bm.Variable, ReturnInfo]) -> Delay:
Expand Down
12 changes: 10 additions & 2 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,14 @@ def has_aft_update(self, key: Any):
def reset_bef_updates(self, *args, **kwargs):
"""Reset all before updates."""
for node in self.before_updates.values():
node.reset_state(*args, **kwargs)
if isinstance(node, DynamicalSystem):
node.reset(*args, **kwargs)

def reset_aft_updates(self, *args, **kwargs):
"""Reset all after updates."""
for node in self.after_updates.values():
node.reset_state(*args, **kwargs)
if isinstance(node, DynamicalSystem):
node.reset(*args, **kwargs)

def update(self, *args, **kwargs):
"""The function to specify the updating rule.
Expand Down Expand Up @@ -349,6 +351,12 @@ def _compatible_update(self, *args, **kwargs):
return ret
return update_fun(*args, **kwargs)

# def __getattr__(self, item):
# if item == 'update':
# return self._compatible_update # update function compatible with previous ``update()`` function
# else:
# return object.__getattribute__(self, item)

def __getattribute__(self, item):
if item == 'update':
return self._compatible_update # update function compatible with previous ``update()`` function
Expand Down
1 change: 1 addition & 0 deletions brainpy/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
coo2mat_num as coo2mat_num,
mat2mat_num as mat2mat_num,
visualizeMat as visualizeMat,
set_default_dtype as set_default_dtype,

CONN_MAT,
PRE_IDS, POST_IDS,
Expand Down