From 862ac2073c83462f040db26770ce297904c17ee2 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 5 Dec 2024 10:29:51 +0000 Subject: [PATCH] Deployed 79f311c to dev with MkDocs 1.6.1 and mike 2.1.3 --- dev/reference/modules/index.html | 104 +++++++++++++++---------------- dev/search/search_index.json | 2 +- 2 files changed, 50 insertions(+), 56 deletions(-) diff --git a/dev/reference/modules/index.html b/dev/reference/modules/index.html index 71b91a0d..9511184e 100644 --- a/dev/reference/modules/index.html +++ b/dev/reference/modules/index.html @@ -14432,8 +14432,7 @@

572 573 574 -575 -576
class Network(Module):
+575
class Network(Module):
     """Network class.
 
     This class defines a network of cells that can be connected with synapses.
@@ -14830,9 +14829,10 @@ 

color: str = "k", synapse_color: str = "b", dims: Tuple[int] = (0, 1), - type: str = "line", cell_plot_kwargs: Dict = {}, synapse_plot_kwargs: Dict = {}, + synapse_scatter_kwargs: Dict = {}, + **kwargs, # absorb add. kwargs, i.e. to enable net.cell(0).vis(type="line") ) -> Axes: """Visualize the module. @@ -14840,19 +14840,17 @@

detail: Either of [point, full]. `point` visualizes every neuron in the network as a dot. `full` plots the full morphology of every neuron. It requires that - `compute_xyz()` has been run and allows for indivual neurons to be - moved with `.move()`. - color: The color in which cells are plotted. Only takes effect if - `detail='full'`. - type: Either `line` or `scatter`. Only takes effect if `detail='full'`. - synapse_color: The color in which synapses are plotted. Only takes effect if - `detail='full'`. + `compute_xyz()` has been run. + color: The color in which cells are plotted. + synapse_color: The color in which synapses are plotted. dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them. cell_plot_kwargs: Keyword arguments passed to the plotting function for cell morphologies. Only takes effect for `detail='full'`. synapse_plot_kwargs: Keyword arguments passed to the plotting function for - syanpses. Only takes effect for `detail='full'`. + syanpses. + synapse_scatter_kwargs: Keyword arguments passed to the scatter function for + syanpse terminals. """ xyz0 = self.cell(0).xyzr[0][:, :3] same_xyz = np.all([np.all(xyz0 == cell.xyzr[0][:, :3]) for cell in self.cells]) @@ -14874,9 +14872,7 @@

pos = cell_to_point_xyz(cell)[dims_np] ax.scatter(*pos, color=color, **cell_plot_kwargs) elif detail == "full": - ax = super().vis( - dims=dims, color=color, ax=ax, type=type, **cell_plot_kwargs - ) + ax = super().vis(dims=dims, color=color, ax=ax, **cell_plot_kwargs) else: raise ValueError("detail must be in {full, point}.") @@ -14887,7 +14883,7 @@

loc, comp = edge[[prepost + "_locs", prepost + "_global_comp_index"]] branch = nodes.loc[comp, "global_branch_index"] cell = nodes.loc[comp, "global_cell_index"] - branch_xyz = self.xyzr[branch] + branch_xyz = self.xyzr[branch][:, :3] xyz_loc = branch_xyz if detail == "point": @@ -14903,8 +14899,10 @@

prepost_locs.append(xyz_loc) prepost_locs = np.stack(prepost_locs).T - ax.plot(*prepost_locs[dims_np], color=synapse_color, **synapse_plot_kwargs) + ax.scatter( + *prepost_locs[dims_np, 1], color=synapse_color, **synapse_scatter_kwargs + ) return ax @@ -15358,7 +15356,7 @@

- vis(detail='full', ax=None, color='k', synapse_color='b', dims=(0, 1), type='line', cell_plot_kwargs={}, synapse_plot_kwargs={}) + vis(detail='full', ax=None, color='k', synapse_color='b', dims=(0, 1), cell_plot_kwargs={}, synapse_plot_kwargs={}, synapse_scatter_kwargs={}, **kwargs)

@@ -15391,8 +15389,7 @@

Either of [point, full]. point visualizes every neuron in the network as a dot. full plots the full morphology of every neuron. It requires that -compute_xyz() has been run and allows for indivual neurons to be -moved with .move().

+compute_xyz() has been run.

@@ -15408,8 +15405,7 @@

-

The color in which cells are plotted. Only takes effect if -detail='full'.

+

The color in which cells are plotted.

@@ -15418,57 +15414,57 @@

- type + synapse_color str
-

Either line or scatter. Only takes effect if detail='full'.

+

The color in which synapses are plotted.

- 'line' + 'b' - synapse_color + dims - str + Tuple[int]
-

The color in which synapses are plotted. Only takes effect if -detail='full'.

+

Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of +two of them.

- 'b' + (0, 1) - dims + cell_plot_kwargs - Tuple[int] + Dict
-

Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of -two of them.

+

Keyword arguments passed to the plotting function for +cell morphologies. Only takes effect for detail='full'.

- (0, 1) + {} - cell_plot_kwargs + synapse_plot_kwargs Dict @@ -15476,7 +15472,7 @@

Keyword arguments passed to the plotting function for -cell morphologies. Only takes effect for detail='full'.

+syanpses.

@@ -15485,15 +15481,15 @@

- synapse_plot_kwargs + synapse_scatter_kwargs Dict
-

Keyword arguments passed to the plotting function for -syanpses. Only takes effect for detail='full'.

+

Keyword arguments passed to the scatter function for +syanpse terminals.

@@ -15587,17 +15583,17 @@

503 504 505 -506 -507
def vis(
+506
def vis(
     self,
     detail: str = "full",
     ax: Optional[Axes] = None,
     color: str = "k",
     synapse_color: str = "b",
     dims: Tuple[int] = (0, 1),
-    type: str = "line",
     cell_plot_kwargs: Dict = {},
     synapse_plot_kwargs: Dict = {},
+    synapse_scatter_kwargs: Dict = {},
+    **kwargs,  # absorb add. kwargs, i.e. to enable net.cell(0).vis(type="line")
 ) -> Axes:
     """Visualize the module.
 
@@ -15605,19 +15601,17 @@ 

detail: Either of [point, full]. `point` visualizes every neuron in the network as a dot. `full` plots the full morphology of every neuron. It requires that - `compute_xyz()` has been run and allows for indivual neurons to be - moved with `.move()`. - color: The color in which cells are plotted. Only takes effect if - `detail='full'`. - type: Either `line` or `scatter`. Only takes effect if `detail='full'`. - synapse_color: The color in which synapses are plotted. Only takes effect if - `detail='full'`. + `compute_xyz()` has been run. + color: The color in which cells are plotted. + synapse_color: The color in which synapses are plotted. dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them. cell_plot_kwargs: Keyword arguments passed to the plotting function for cell morphologies. Only takes effect for `detail='full'`. synapse_plot_kwargs: Keyword arguments passed to the plotting function for - syanpses. Only takes effect for `detail='full'`. + syanpses. + synapse_scatter_kwargs: Keyword arguments passed to the scatter function for + syanpse terminals. """ xyz0 = self.cell(0).xyzr[0][:, :3] same_xyz = np.all([np.all(xyz0 == cell.xyzr[0][:, :3]) for cell in self.cells]) @@ -15639,9 +15633,7 @@

pos = cell_to_point_xyz(cell)[dims_np] ax.scatter(*pos, color=color, **cell_plot_kwargs) elif detail == "full": - ax = super().vis( - dims=dims, color=color, ax=ax, type=type, **cell_plot_kwargs - ) + ax = super().vis(dims=dims, color=color, ax=ax, **cell_plot_kwargs) else: raise ValueError("detail must be in {full, point}.") @@ -15652,7 +15644,7 @@

loc, comp = edge[[prepost + "_locs", prepost + "_global_comp_index"]] branch = nodes.loc[comp, "global_branch_index"] cell = nodes.loc[comp, "global_cell_index"] - branch_xyz = self.xyzr[branch] + branch_xyz = self.xyzr[branch][:, :3] xyz_loc = branch_xyz if detail == "point": @@ -15668,8 +15660,10 @@

prepost_locs.append(xyz_loc) prepost_locs = np.stack(prepost_locs).T - ax.plot(*prepost_locs[dims_np], color=synapse_color, **synapse_plot_kwargs) + ax.scatter( + *prepost_locs[dims_np, 1], color=synapse_color, **synapse_scatter_kwargs + ) return ax

diff --git a/dev/search/search_index.json b/dev/search/search_index.json index 57fecd83..04119912 100644 --- a/dev/search/search_index.json +++ b/dev/search/search_index.json @@ -1 +1 @@ -{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"Home","text":"

The official documentation for Jaxley has moved to jaxley.readthedocs.io. The website you are currently on will be taken down in the future.

Jaxley is a differentiable simulator for biophysical neuron models in JAX. Its key features are:

  • automatic differentiation, allowing gradient-based optimization of thousands of parameters
  • support for CPU, GPU, or TPU without any changes to the code
  • jit-compilation, making it as fast as other packages while being fully written in python
  • backward-Euler solver for stable numerical solution of multicompartment neurons
  • elegant mechanisms for parameter sharing
"},{"location":"#getting-started","title":"Getting started","text":"

Jaxley allows to simulate biophysical neuron models on CPU, GPU, or TPU:

import matplotlib.pyplot as plt\nfrom jax import config\n\nimport jaxley as jx\nfrom jaxley.channels import HH\n\nconfig.update(\"jax_platform_name\", \"cpu\")  # Or \"gpu\" / \"tpu\".\n\ncell = jx.Cell()  # Define cell.\ncell.insert(HH())  # Insert channels.\n\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=10.0)\ncell.stimulate(current)  # Stimulate with step current.\ncell.record(\"v\")  # Record voltage.\n\nv = jx.integrate(cell)  # Run simulation.\nplt.plot(v.T)  # Plot voltage trace.\n

If you want to learn more, we have tutorials on how to:

  • simulate morphologically detailed neurons
  • simulate networks of such neurons
  • set parameters of cells and networks
  • speed up simulations with GPUs and jit
  • define your own channels and synapses
  • define groups
  • read and handle SWC files
  • compute the gradient and train biophysical models
"},{"location":"#installation","title":"Installation","text":"

Jaxley is available on pypi:

pip install jaxley\n
This will install Jaxley with CPU support. If you want GPU support, follow the instructions on the JAX github repository to install JAX with GPU support (in addition to installing Jaxley). For example, for NVIDIA GPUs, run
pip install -U \"jax[cuda12]\"\n

"},{"location":"#feedback-and-contributions","title":"Feedback and Contributions","text":"

We welcome any feedback on how Jaxley is working for your neuron models and are happy to receive bug reports, pull requests and other feedback (see contribute). We wish to maintain a positive community, please read our Code of Conduct.

"},{"location":"#license","title":"License","text":"

Apache License Version 2.0 (Apache-2.0)

"},{"location":"#citation","title":"Citation","text":"

If you use Jaxley, consider citing the corresponding paper:

@article{deistler2024differentiable,\n  doi = {10.1101/2024.08.21.608979},\n  year = {2024},\n  publisher = {Cold Spring Harbor Laboratory},\n  author = {Deistler, Michael and Kadhim, Kyra L. and Pals, Matthijs and Beck, Jonas and Huang, Ziwei and Gloeckler, Manuel and Lappalainen, Janne K. and Schr{\\\"o}der, Cornelius and Berens, Philipp and Gon{\\c c}alves, Pedro J. and Macke, Jakob H.},\n  title = {Differentiable simulation enables large-scale training of detailed biophysical models of neural dynamics},\n  journal = {bioRxiv}\n}\n
"},{"location":"code_of_conduct/","title":"Contributor Covenant Code of Conduct","text":""},{"location":"code_of_conduct/#our-pledge","title":"Our Pledge","text":"

We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.

"},{"location":"code_of_conduct/#our-standards","title":"Our Standards","text":"

Examples of behavior that contributes to a positive environment for our community include:

  • Demonstrating empathy and kindness toward other people
  • Being respectful of differing opinions, viewpoints, and experiences
  • Giving and gracefully accepting constructive feedback
  • Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
  • Focusing on what is best not just for us as individuals, but for the overall community

Examples of unacceptable behavior include:

  • The use of sexualized language or imagery, and sexual attention or advances of any kind
  • Trolling, insulting or derogatory comments, and personal or political attacks
  • Public or private harassment
  • Publishing others\u2019 private information, such as a physical or email address, without their explicit permission
  • Other conduct which could reasonably be considered inappropriate in a professional setting
"},{"location":"code_of_conduct/#enforcement-responsibilities","title":"Enforcement Responsibilities","text":"

Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.

Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.

"},{"location":"code_of_conduct/#scope","title":"Scope","text":"

This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.

"},{"location":"code_of_conduct/#enforcement","title":"Enforcement","text":"

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting jaxley developer Michael Deistler via email (michael.deistler@uni-tuebingen.de). All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the reporter of any incident.

"},{"location":"code_of_conduct/#enforcement-guidelines","title":"Enforcement Guidelines","text":"

Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:

"},{"location":"code_of_conduct/#1-correction","title":"1. Correction","text":"

Community Impact: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.

Consequence: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.

"},{"location":"code_of_conduct/#2-warning","title":"2. Warning","text":"

Community Impact: A violation through a single incident or series of actions.

Consequence: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.

"},{"location":"code_of_conduct/#3-temporary-ban","title":"3. Temporary Ban","text":"

Community Impact: A serious violation of community standards, including sustained inappropriate behavior.

Consequence: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.

"},{"location":"code_of_conduct/#4-permanent-ban","title":"4. Permanent Ban","text":"

Community Impact: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.

Consequence: A permanent ban from any sort of public interaction within the community.

"},{"location":"code_of_conduct/#attribution","title":"Attribution","text":"

This Code of Conduct is adapted from the Contributor Covenant, version 2.1, available at https://www.contributor-covenant.org/version/2/1/code_of_conduct.html.

Community Impact Guidelines were inspired by Mozilla\u2019s code of conduct enforcement ladder.

For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.

"},{"location":"contribute/","title":"Guide","text":""},{"location":"contribute/#user-experiences-bugs-and-feature-requests","title":"User experiences, bugs, and feature requests","text":"

To report bugs and suggest features (including better documentation), please head over to issues on GitHub.

"},{"location":"contribute/#code-contributions","title":"Code contributions","text":"

In general, we use pull requests to make changes to Jaxley. So, if you are planning to make a contribution, please fork, create a feature branch and then make a PR from your feature branch to the upstream Jaxley (details).

"},{"location":"contribute/#development-environment","title":"Development environment","text":"

Clone the repo and install via setup.py using pip install -e \".[dev]\" (the dev flag installs development and testing dependencies).

"},{"location":"contribute/#style-conventions","title":"Style conventions","text":"

For docstrings and comments, we use Google Style.

Code needs to pass through the following tools, which are installed alongside Jaxley:

black: Automatic code formatting for Python. You can run black manually from the console using black . in the top directory of the repository, which will format all files.

isort: Used to consistently order imports. You can run isort manually from the console using isort in the top directory.

black and isort are checked as part of our CI actions. If these checks fail please make sure you have installed the latest versions for each of them and run them locally.

"},{"location":"contribute/#online-documentation","title":"Online documentation","text":"

Most of the documentation is written in markdown (basic markdown guide).

You can directly fix mistakes and suggest clearer formulations in markdown files simply by initiating a PR on through GitHub. Click on documentation file and look for the little pencil at top right.

"},{"location":"credits/","title":"Credits","text":"

Jaxley is a collaborative project between the groups of Jakob Macke (Uni T\u00fcbingen), Pedro Gon\u00e7alves (KU Leuven / NERF), and Philipp Berens (Uni T\u00fcbingen).

"},{"location":"credits/#license","title":"License","text":"

Jaxley is licensed under the Apache License Version 2.0 (Apache-2.0) and

Copyright (C) 2024 Michael Deistler, Jakob H. Macke, Pedro J. Goncalves, Philipp Berens.

"},{"location":"credits/#important-dependencies-and-prior-art","title":"Important dependencies and prior art","text":"
  • We greatly benefited from previous toolboxes for simulating multicompartment neurons, in particular NEURON.
"},{"location":"credits/#funding","title":"Funding","text":"

This work was supported by the German Research Foundation (DFG) through Germany\u2019s Excellence Strategy (EXC 2064 \u2013 Project number 390727645) and the CRC 1233 \u201cRobust Vision\u201d, the German Federal Ministry of Education and Research (Tu\u0308bingen AI Center, FKZ: 01IS18039A), the \u2018Certification and Foundations of Safe Machine Learning Systems in Healthcare\u2019 project funded by the Carl Zeiss Foundation, and the European Union (ERC, \u201cDeepCoMechTome\u201d, ref. 101089288, \u201cNextMechMod\u201d, ref. 101039115).

"},{"location":"faq/","title":"Frequently asked questions","text":"
  • What kinds of models can be implemented in Jaxley?
  • What units does Jaxley use?
  • How can I save and load cells and networks?

See also the discussion page and the issue tracker on the Jaxley GitHub repository for recent questions and problems.

"},{"location":"install/","title":"Installation","text":""},{"location":"install/#install-the-most-recent-stable-version","title":"Install the most recent stable version","text":"

Jaxley is available on PyPI:

pip install jaxley\n
This will install Jaxley with CPU support. If you want GPU support, follow the instructions on the JAX github repository to install JAX with GPU support (in addition to installing Jaxley). For example, for NVIDIA GPUs, run
pip install -U \"jax[cuda12]\"\n

"},{"location":"install/#install-from-source","title":"Install from source","text":"

You can also install Jaxley from source:

git clone https://github.com/jaxleyverse/jaxley.git\ncd jaxley\npip install -e .\n

Note that pip>=21.3 is required to install the editable version with pyproject.toml see pip docs.

"},{"location":"faq/question_01/","title":"What units does Jaxley use?","text":"

Jaxley uses the same units as the NEURON simulator, which are listed here.

"},{"location":"faq/question_02/","title":"How can I save and load cells and networks?","text":"

All modules (i.e., compartments, branches, cells, and networks) in Jaxley can be saved and loaded with pickle:

import jaxley as jx\nimport pickle\n\n# ... define network, cell, etc.\nnetwork = jx.Network([cell1, cell2])\n\n# Save.\nwith open(\"path/to/file.pkl\", \"wb\") as handle:\n    pickle.dump(network, handle)\n\n# Load.\nwith open(\"path/to/file.pkl\", \"rb\") as handle:\n    network = pickle.load(handle)\n

"},{"location":"faq/question_03/","title":"What kinds of models can be implemented in Jaxley?","text":"

Jaxley focuses on biophysical, Hodgkin-Huxley-type models. You can think of Jaxley like the NEURON simulator written in JAX.

Jaxley allows to simulate the following types of models, as well as networks thereof:

  • single-compartment (point neuron) Hodgkin-Huxley models
  • multi-compartment Hodgkin-Huxley models
  • rate-based neuron models

For all of these models, Jaxley is flexible and accurate. For example, it can flexibly add new channel models, use different kinds of synapses (conductance-based, tanh, \u2026), and it can insert different kinds of channels in different branches (or compartments) within single cells. Like NEURON, Jaxley implements a backward-Euler solver for stable numerical solution of multi-compartment neurons.

However, Jaxley does not implement the following types of models:

  • leaky-integrate and fire neurons
  • Ishikevich neuron models
  • etc\u2026
"},{"location":"reference/connect/","title":"Connecting Cells","text":""},{"location":"reference/connect/#jaxley.connect.connect","title":"connect(pre, post, synapse_type)","text":"

Connect two compartments with a chemical synapse.

The pre- and postsynaptic compartments must be different compartments of the same network.

Parameters:

Name Type Description Default pre View

View of the presynaptic compartment.

required post View

View of the postsynaptic compartment.

required synapse_type Synapse

The synapse to append

required Source code in jaxley/connect.py
def connect(\n    pre: \"View\",\n    post: \"View\",\n    synapse_type: \"Synapse\",\n):\n    \"\"\"Connect two compartments with a chemical synapse.\n\n    The pre- and postsynaptic compartments must be different compartments of the\n    same network.\n\n    Args:\n        pre: View of the presynaptic compartment.\n        post: View of the postsynaptic compartment.\n        synapse_type: The synapse to append\n    \"\"\"\n    assert is_same_network(\n        pre, post\n    ), \"Pre and post compartments must be part of the same network.\"\n\n    pre.base._append_multiple_synapses(pre.nodes, post.nodes, synapse_type)\n
"},{"location":"reference/connect/#jaxley.connect.connectivity_matrix_connect","title":"connectivity_matrix_connect(pre_cell_view, post_cell_view, synapse_type, connectivity_matrix)","text":"

Appends multiple connections which build a custom connected network.

Connects pre- and postsynaptic cells according to a custom connectivity matrix. Entries > 0 in the matrix indicate a connection between the corresponding cells. Connections are from branch 0 location 0 to a randomly chosen branch and loc.

Parameters:

Name Type Description Default pre_cell_view View

View of the presynaptic cell.

required post_cell_view View

View of the postsynaptic cell.

required synapse_type Synapse

The synapse to append.

required connectivity_matrix ndarray[bool]

A boolean matrix indicating the connections between cells.

required Source code in jaxley/connect.py
def connectivity_matrix_connect(\n    pre_cell_view: \"View\",\n    post_cell_view: \"View\",\n    synapse_type: \"Synapse\",\n    connectivity_matrix: np.ndarray[bool],\n):\n    \"\"\"Appends multiple connections which build a custom connected network.\n\n    Connects pre- and postsynaptic cells according to a custom connectivity matrix.\n    Entries > 0 in the matrix indicate a connection between the corresponding cells.\n    Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n    Args:\n        pre_cell_view: View of the presynaptic cell.\n        post_cell_view: View of the postsynaptic cell.\n        synapse_type: The synapse to append.\n        connectivity_matrix: A boolean matrix indicating the connections between cells.\n    \"\"\"\n    # Get pre- and postsynaptic cell indices.\n    pre_cell_inds = pre_cell_view._cells_in_view\n    post_cell_inds = post_cell_view._cells_in_view\n    # setting scope ensure that this works indep of current scope\n    pre_nodes = pre_cell_view.scope(\"local\").branch(0).comp(0).nodes\n    pre_nodes[\"index\"] = pre_nodes.index\n    pre_cell_nodes = pre_nodes.set_index(\"global_cell_index\")\n\n    assert connectivity_matrix.shape == (\n        len(pre_cell_inds),\n        len(post_cell_inds),\n    ), \"Connectivity matrix must have shape (num_pre, num_post).\"\n    assert connectivity_matrix.dtype == bool, \"Connectivity matrix must be boolean.\"\n\n    # get connection pairs from connectivity matrix\n    from_idx, to_idx = np.where(connectivity_matrix)\n    pre_cell_inds = pre_cell_inds[from_idx]\n    post_cell_inds = post_cell_inds[to_idx]\n\n    # Sample random postsynaptic compartments (global comp indices).\n    global_post_indices = np.hstack(\n        [\n            sample_comp(post_cell_view.scope(\"global\").cell(cell_idx))\n            for cell_idx in post_cell_inds\n        ]\n    )\n    post_rows = post_cell_view.nodes.loc[global_post_indices]\n\n    # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n    global_pre_indices = pre_cell_nodes.loc[pre_cell_inds, \"index\"].to_numpy()\n    pre_rows = pre_cell_view.select(nodes=global_pre_indices).nodes\n\n    pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
"},{"location":"reference/connect/#jaxley.connect.fully_connect","title":"fully_connect(pre_cell_view, post_cell_view, synapse_type)","text":"

Appends multiple connections which build a fully connected layer.

Connections are from branch 0 location 0 to a randomly chosen branch and loc.

Parameters:

Name Type Description Default pre_cell_view View

View of the presynaptic cell.

required post_cell_view View

View of the postsynaptic cell.

required synapse_type Synapse

The synapse to append.

required Source code in jaxley/connect.py
def fully_connect(\n    pre_cell_view: \"View\",\n    post_cell_view: \"View\",\n    synapse_type: \"Synapse\",\n):\n    \"\"\"Appends multiple connections which build a fully connected layer.\n\n    Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n    Args:\n        pre_cell_view: View of the presynaptic cell.\n        post_cell_view: View of the postsynaptic cell.\n        synapse_type: The synapse to append.\n    \"\"\"\n    # Get pre- and postsynaptic cell indices.\n    num_pre = len(pre_cell_view._cells_in_view)\n    num_post = len(post_cell_view._cells_in_view)\n\n    # Infer indices of (random) postsynaptic compartments.\n    global_post_indices = (\n        post_cell_view.nodes.groupby(\"global_cell_index\")\n        .sample(num_pre, replace=True)\n        .index.to_numpy()\n    )\n    global_post_indices = global_post_indices.reshape((-1, num_pre), order=\"F\").ravel()\n    post_rows = post_cell_view.nodes.loc[global_post_indices]\n\n    # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n    pre_rows = pre_cell_view.scope(\"local\").branch(0).comp(0).nodes.copy()\n    # Repeat rows `num_post` times. See SO 50788508.\n    pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)].reset_index(drop=True)\n\n    pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
"},{"location":"reference/connect/#jaxley.connect.is_same_network","title":"is_same_network(pre, post)","text":"

Check if views are from the same network.

Source code in jaxley/connect.py
def is_same_network(pre: \"View\", post: \"View\") -> bool:\n    \"\"\"Check if views are from the same network.\"\"\"\n    is_in_net = \"network\" in pre.base.__class__.__name__.lower()\n    is_in_same_net = pre.base is post.base\n    return is_in_net and is_in_same_net\n
"},{"location":"reference/connect/#jaxley.connect.sample_comp","title":"sample_comp(cell_view, num=1, replace=True)","text":"

Sample a compartment from a cell.

Returns View with shape (num, num_cols).

Source code in jaxley/connect.py
def sample_comp(cell_view: \"View\", num: int = 1, replace=True) -> \"CompartmentView\":\n    \"\"\"Sample a compartment from a cell.\n\n    Returns View with shape (num, num_cols).\"\"\"\n    return np.random.choice(cell_view._comps_in_view, num, replace=replace)\n
"},{"location":"reference/connect/#jaxley.connect.sparse_connect","title":"sparse_connect(pre_cell_view, post_cell_view, synapse_type, p)","text":"

Appends multiple connections which build a sparse, randomly connected layer.

Connections are from branch 0 location 0 to a randomly chosen branch and loc.

Parameters:

Name Type Description Default pre_cell_view View

View of the presynaptic cell.

required post_cell_view View

View of the postsynaptic cell.

required synapse_type Synapse

The synapse to append.

required p float

Probability of connection.

required Source code in jaxley/connect.py
def sparse_connect(\n    pre_cell_view: \"View\",\n    post_cell_view: \"View\",\n    synapse_type: \"Synapse\",\n    p: float,\n):\n    \"\"\"Appends multiple connections which build a sparse, randomly connected layer.\n\n    Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n    Args:\n        pre_cell_view: View of the presynaptic cell.\n        post_cell_view: View of the postsynaptic cell.\n        synapse_type: The synapse to append.\n        p: Probability of connection.\n    \"\"\"\n    # Get pre- and postsynaptic cell indices.\n    pre_cell_inds = pre_cell_view._cells_in_view\n    post_cell_inds = post_cell_view._cells_in_view\n    num_pre = len(pre_cell_inds)\n    num_post = len(post_cell_inds)\n\n    num_connections = np.random.binomial(num_pre * num_post, p)\n    pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections)\n    post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections)\n\n    # Sort the synapses only for convenience of inspecting `.edges`.\n    sorting = np.argsort(pre_syn_neurons)\n    pre_syn_neurons = pre_syn_neurons[sorting]\n    post_syn_neurons = post_syn_neurons[sorting]\n\n    # Post-synapse is a randomly chosen branch and compartment.\n    global_post_indices = [\n        sample_comp(post_cell_view.scope(\"global\").cell(cell_idx))\n        for cell_idx in post_syn_neurons\n    ]\n    global_post_indices = (\n        np.hstack(global_post_indices) if len(global_post_indices) > 1 else []\n    )\n    post_rows = post_cell_view.base.nodes.loc[global_post_indices]\n\n    # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n    global_pre_indices = pre_cell_view.base._cumsum_ncomp_per_cell[pre_syn_neurons]\n    pre_rows = pre_cell_view.base.nodes.loc[global_pre_indices]\n\n    if len(pre_rows) > 0:\n        pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
"},{"location":"reference/integration/","title":"Simulation","text":""},{"location":"reference/integration/#jaxley.integrate.add_clamps","title":"add_clamps(externals, external_inds, data_clamps=None)","text":"

Adds clamps to the external inputs.

Parameters:

Name Type Description Default externals Dict

Current external inputs.

required external_inds Dict

Current external indices.

required data_clamps Optional[Tuple[str, ndarray, DataFrame]]

Additional data clamps. Defaults to None.

None

Returns:

Type Description Tuple[Dict, Dict]

Tuple[Dict, Dict]: Updated external inputs and indices.

Source code in jaxley/integrate.py
def add_clamps(\n    externals: Dict,\n    external_inds: Dict,\n    data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None,\n) -> Tuple[Dict, Dict]:\n    \"\"\"Adds clamps to the external inputs.\n\n    Args:\n        externals (Dict): Current external inputs.\n        external_inds (Dict): Current external indices.\n        data_clamps (Optional[Tuple[str, jnp.ndarray, pd.DataFrame]], optional): Additional data clamps. Defaults to None.\n\n    Returns:\n        Tuple[Dict, Dict]: Updated external inputs and indices.\n    \"\"\"\n    # If a clamp is inserted, add it to the external inputs.\n    if data_clamps is not None:\n        state_name, clamps, inds = data_clamps\n        if state_name in externals.keys():\n            externals[state_name] = jnp.concatenate([externals[state_name], clamps])\n            external_inds[state_name] = jnp.concatenate(\n                [external_inds[state_name], inds.index.to_numpy()]\n            )\n        else:\n            externals[state_name] = clamps\n            external_inds[state_name] = inds.index.to_numpy()\n\n    return externals, external_inds\n
"},{"location":"reference/integration/#jaxley.integrate.add_stimuli","title":"add_stimuli(externals, external_inds, data_stimuli=None)","text":"

Extends the external inputs with the stimuli.

Parameters:

Name Type Description Default externals Dict

Current external inputs.

required external_inds Dict

Current external indices.

required data_stimuli Optional[Tuple[ndarray, DataFrame]]

Additional data stimuli. Defaults to None.

None

Returns:

Type Description Tuple[Dict, Dict]

Tuple[Dict, Dict]: Updated external inputs and indices.

Source code in jaxley/integrate.py
def add_stimuli(\n    externals: Dict,\n    external_inds: Dict,\n    data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n) -> Tuple[Dict, Dict]:\n    \"\"\"Extends the external inputs with the stimuli.\n\n    Args:\n        externals (Dict): Current external inputs.\n        external_inds (Dict): Current external indices.\n        data_stimuli (Optional[Tuple[jnp.ndarray, pd.DataFrame]], optional): Additional data stimuli. Defaults to None.\n\n    Returns:\n        Tuple[Dict, Dict]: Updated external inputs and indices.\n    \"\"\"\n    # If stimulus is inserted, add it to the external inputs.\n    if \"i\" in externals.keys() or data_stimuli is not None:\n        if \"i\" in externals.keys():\n            if data_stimuli is not None:\n                externals[\"i\"] = jnp.concatenate([externals[\"i\"], data_stimuli[1]])\n                external_inds[\"i\"] = jnp.concatenate(\n                    [external_inds[\"i\"], data_stimuli[2].index.to_numpy()]\n                )\n        else:\n            externals[\"i\"] = data_stimuli[1]\n            external_inds[\"i\"] = data_stimuli[2].index.to_numpy()\n\n    return externals, external_inds\n
"},{"location":"reference/integration/#jaxley.integrate.build_init_and_step_fn","title":"build_init_and_step_fn(module, voltage_solver='jaxley.stone', solver='bwd_euler')","text":"

This function returns the init_fn and step_fn which initialize the parameters and states of the neuron model and then step through the model

Parameters:

Name Type Description Default module Module

A Module object that e.g. a cell.

required voltage_solver str

Voltage solver used in step. Defaults to \u201cjaxley.stone\u201d.

'jaxley.stone' solver str

ODE solver. Defaults to \u201cbwd_euler\u201d.

'bwd_euler'

Returns:

Type Description Tuple[Callable, Callable]

init_fn, step_fn: Functions that initialize the state and parameters, and perform a single integration step, respectively.

Source code in jaxley/integrate.py
def build_init_and_step_fn(\n    module: Module,\n    voltage_solver: str = \"jaxley.stone\",\n    solver: str = \"bwd_euler\",\n) -> Tuple[Callable, Callable]:\n    \"\"\"This function returns the `init_fn` and `step_fn` which initialize the\n    parameters and states of the neuron model and then step through the model\n\n    Args:\n        module (Module): A `Module` object that e.g. a cell.\n        voltage_solver (str, optional): Voltage solver used in step. Defaults to \"jaxley.stone\".\n        solver (str, optional): ODE solver. Defaults to \"bwd_euler\".\n\n    Returns:\n        init_fn, step_fn: Functions that initialize the state and parameters, and perform\n            a single integration step, respectively.\n    \"\"\"\n    # Initialize the external inputs and their indices.\n    external_inds = module.external_inds.copy()\n\n    def init_fn(\n        params: List[Dict[str, jnp.ndarray]],\n        all_states: Optional[Dict] = None,\n        param_state: Optional[List[Dict]] = None,\n        delta_t: float = 0.025,\n    ) -> Tuple[Dict, Dict]:\n        \"\"\"Initializes the parameters and states of the neuron model.\n\n        Args:\n            params (List[Dict[str, jnp.ndarray]]): List of trainable parameters.\n            all_states (Optional[Dict], optional): State if alread initialized. Defaults to None.\n            param_state (Optional[List[Dict]], optional): Parameters returned by `data_set`.. Defaults to None.\n            delta_t (float, optional): Step size. Defaults to 0.025.\n\n        Returns:\n            Tuple[Dict, Dict]: All states and parameters.\n        \"\"\"\n        # Make the `trainable_params` of the same shape as the `param_state`, such that\n        # they can be processed together by `get_all_parameters`.\n        pstate = params_to_pstate(params, module.indices_set_by_trainables)\n        if param_state is not None:\n            pstate += param_state\n\n        all_params = module.get_all_parameters(pstate, voltage_solver=voltage_solver)\n        all_states = (\n            module.get_all_states(pstate, all_params, delta_t)\n            if all_states is None\n            else all_states\n        )\n        return all_states, all_params\n\n    def step_fn(\n        all_states: Dict,\n        all_params: Dict,\n        externals: Dict,\n        external_inds: Dict = external_inds,\n        delta_t: float = 0.025,\n    ) -> Dict:\n        \"\"\"Performs a single integration step with step size delta_t.\n\n        Args:\n            all_states (Dict): Current state of the neuron model.\n            all_params (Dict): Current parameters of the neuron model.\n            externals (Dict): External inputs.\n            external_inds (Dict, optional): External indices. Defaults to `module.external_inds`.\n            delta_t (float, optional): Time step. Defaults to 0.025.\n\n        Returns:\n            Dict: Updated states.\n        \"\"\"\n        state = all_states\n        state = module.step(\n            state,\n            delta_t,\n            external_inds,\n            externals,\n            params=all_params,\n            solver=solver,\n            voltage_solver=voltage_solver,\n        )\n        return state\n\n    return init_fn, step_fn\n
"},{"location":"reference/integration/#jaxley.integrate.integrate","title":"integrate(module, params=[], *, param_state=None, data_stimuli=None, data_clamps=None, t_max=None, delta_t=0.025, solver='bwd_euler', voltage_solver='jaxley.stone', checkpoint_lengths=None, all_states=None, return_states=False)","text":"

Solves ODE and simulates neuron model.

Parameters:

Name Type Description Default params List[Dict[str, ndarray]]

Trainable parameters returned by get_parameters().

[] param_state Optional[List[Dict]]

Parameters returned by data_set.

None data_stimuli Optional[Tuple[ndarray, DataFrame]]

Outputs of .data_stimulate(), only needed if stimuli change across function calls.

None data_clamps Optional[Tuple[str, ndarray, DataFrame]]

Outputs of .data_clamp(), only needed if clamps change across function calls.

None t_max Optional[float]

Duration of the simulation in milliseconds. If t_max is greater than the length of the stimulus input, the stimulus will be padded at the end with zeros. If t_max is smaller, then the stimulus with be truncated.

None delta_t float

Time step of the solver in milliseconds.

0.025 solver str

Which ODE solver to use. Either of [\u201cfwd_euler\u201d, \u201cbwd_euler\u201d, \u201ccrank_nicolson\u201d].

'bwd_euler' tridiag_solver

Algorithm to solve tridiagonal systems. The different options only affect bwd_euler and crank_nicolson solvers. Either of [\u201cstone\u201d, \u201cthomas\u201d], where stone is much faster on GPU for long branches with many compartments and thomas is slightly faster on CPU (thomas is used in NEURON).

required checkpoint_lengths Optional[List[int]]

Number of timesteps at every level of checkpointing. The prod(checkpoint_lengths) must be larger or equal to the desired number of simulated timesteps. Warning: the simulation is run for prod(checkpoint_lengths) timesteps, and the result is posthoc truncated to the desired simulation length. Therefore, a poor choice of checkpoint_lengths can lead to longer simulation time. If None, no checkpointing is applied.

None all_states Optional[Dict]

An optional initial state that was returned by a previous jx.integrate(..., return_states=True) run. Overrides potentially trainable initial states.

None return_states bool

If True, it returns all states such that the current state of the Module can be set with set_states.

False Source code in jaxley/integrate.py
def integrate(\n    module: Module,\n    params: List[Dict[str, jnp.ndarray]] = [],\n    *,\n    param_state: Optional[List[Dict]] = None,\n    data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n    data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None,\n    t_max: Optional[float] = None,\n    delta_t: float = 0.025,\n    solver: str = \"bwd_euler\",\n    voltage_solver: str = \"jaxley.stone\",\n    checkpoint_lengths: Optional[List[int]] = None,\n    all_states: Optional[Dict] = None,\n    return_states: bool = False,\n) -> jnp.ndarray:\n    \"\"\"\n    Solves ODE and simulates neuron model.\n\n    Args:\n        params: Trainable parameters returned by `get_parameters()`.\n        param_state: Parameters returned by `data_set`.\n        data_stimuli: Outputs of `.data_stimulate()`, only needed if stimuli change\n            across function calls.\n        data_clamps: Outputs of `.data_clamp()`, only needed if clamps change across\n            function calls.\n        t_max: Duration of the simulation in milliseconds. If `t_max` is greater than\n            the length of the stimulus input, the stimulus will be padded at the end\n            with zeros. If `t_max` is smaller, then the stimulus with be truncated.\n        delta_t: Time step of the solver in milliseconds.\n        solver: Which ODE solver to use. Either of [\"fwd_euler\", \"bwd_euler\",\n            \"crank_nicolson\"].\n        tridiag_solver: Algorithm to solve tridiagonal systems. The  different options\n            only affect `bwd_euler` and `crank_nicolson` solvers. Either of [\"stone\",\n            \"thomas\"], where `stone` is much faster on GPU for long branches\n            with many compartments and `thomas` is slightly faster on CPU (`thomas` is\n            used in NEURON).\n        checkpoint_lengths: Number of timesteps at every level of checkpointing. The\n            `prod(checkpoint_lengths)` must be larger or equal to the desired number of\n            simulated timesteps. Warning: the simulation is run for\n            `prod(checkpoint_lengths)` timesteps, and the result is posthoc truncated\n            to the desired simulation length. Therefore, a poor choice of\n            `checkpoint_lengths` can lead to longer simulation time. If `None`, no\n            checkpointing is applied.\n        all_states: An optional initial state that was returned by a previous\n            `jx.integrate(..., return_states=True)` run. Overrides potentially\n            trainable initial states.\n        return_states: If True, it returns all states such that the current state of\n            the `Module` can be set with `set_states`.\n    \"\"\"\n\n    assert module.initialized, \"Module is not initialized, run `._initialize()`.\"\n    module.to_jax()  # Creates `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.\n\n    # Initialize the external inputs and their indices.\n    externals = module.externals.copy()\n    external_inds = module.external_inds.copy()\n\n    # If stimulus is inserted, add it to the external inputs.\n    externals, external_inds = add_stimuli(externals, external_inds, data_stimuli)\n\n    # If a clamp is inserted, add it to the external inputs.\n    externals, external_inds = add_clamps(externals, external_inds, data_clamps)\n\n    if not externals.keys():\n        # No stimulus was inserted and no clamp was set.\n        assert (\n            t_max is not None\n        ), \"If no stimulus or clamp are inserted you have to specify the simulation duration at `jx.integrate(..., t_max=)`.\"\n\n    for key in externals.keys():\n        externals[key] = externals[key].T  # Shape `(time, num_stimuli)`.\n\n    if module.recordings.empty:\n        raise ValueError(\"No recordings are set. Please set them.\")\n    rec_inds = module.recordings.rec_index.to_numpy()\n    rec_states = module.recordings.state.to_numpy()\n\n    # Shorten or pad stimulus depending on `t_max`.\n    if t_max is not None:\n        t_max_steps = int(t_max // delta_t + 1)\n\n        # Pad or truncate the stimulus.\n        for key in externals.keys():\n            if t_max_steps > externals[key].shape[0]:\n                if key == \"i\":\n                    pad = jnp.zeros(\n                        (t_max_steps - externals[\"i\"].shape[0], externals[\"i\"].shape[1])\n                    )\n                    externals[\"i\"] = jnp.concatenate((externals[\"i\"], pad))\n                else:\n                    raise NotImplementedError(\n                        \"clamp must be at least as long as simulation.\"\n                    )\n            else:\n                externals[key] = externals[key][:t_max_steps, :]\n\n    init_fn, step_fn = build_init_and_step_fn(\n        module, voltage_solver=voltage_solver, solver=solver\n    )\n    all_states, all_params = init_fn(params, all_states, param_state, delta_t)\n\n    def _body_fun(state, externals):\n        state = step_fn(state, all_params, externals, external_inds, delta_t)\n        recs = jnp.asarray(\n            [\n                state[rec_state][rec_ind]\n                for rec_state, rec_ind in zip(rec_states, rec_inds)\n            ]\n        )\n        return state, recs\n\n    # If necessary, pad the stimulus with zeros in order to simulate sufficiently long.\n    # The total simulation length will be `prod(checkpoint_lengths)`. At the end, we\n    # return only the first `nsteps_to_return` elements (plus the initial state).\n    if externals:\n        example_key = list(externals.keys())[0]\n        nsteps_to_return = len(externals[example_key])\n    else:\n        nsteps_to_return = t_max_steps\n\n    if checkpoint_lengths is None:\n        checkpoint_lengths = [nsteps_to_return]\n        length = nsteps_to_return\n    else:\n        length = prod(checkpoint_lengths)\n        size_difference = length - nsteps_to_return\n        assert (\n            nsteps_to_return <= length\n        ), \"The desired simulation duration is longer than `prod(nested_length)`.\"\n        if externals:\n            dummy_external = jnp.zeros(\n                (size_difference, externals[example_key].shape[1])\n            )\n            for key in externals.keys():\n                externals[key] = jnp.concatenate([externals[key], dummy_external])\n\n    # Record the initial state.\n    init_recs = jnp.asarray(\n        [\n            all_states[rec_state][rec_ind]\n            for rec_state, rec_ind in zip(rec_states, rec_inds)\n        ]\n    )\n    init_recording = jnp.expand_dims(init_recs, axis=0)\n\n    # Run simulation.\n    all_states, recordings = nested_checkpoint_scan(\n        _body_fun,\n        all_states,\n        externals,\n        length=length,\n        nested_lengths=checkpoint_lengths,\n    )\n    recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T\n    return (recs, all_states) if return_states else recs\n
"},{"location":"reference/integration/#jaxley.solver_gate.exponential_euler","title":"exponential_euler(x, dt, x_inf, x_tau)","text":"

An exact solver for the linear dynamical system dx = -(x - x_inf) / x_tau.

Source code in jaxley/solver_gate.py
def exponential_euler(\n    x: jnp.ndarray,\n    dt: float,\n    x_inf: jnp.ndarray,\n    x_tau: jnp.ndarray,\n):\n    \"\"\"An exact solver for the linear dynamical system `dx = -(x - x_inf) / x_tau`.\"\"\"\n    exp_term = save_exp(-dt / x_tau)\n    return x * exp_term + x_inf * (1.0 - exp_term)\n
"},{"location":"reference/integration/#jaxley.solver_gate.save_exp","title":"save_exp(x, max_value=20.0)","text":"

Clip the input to a maximum value and return its exponential.

Source code in jaxley/solver_gate.py
def save_exp(x, max_value: float = 20.0):\n    \"\"\"Clip the input to a maximum value and return its exponential.\"\"\"\n    x = jnp.clip(x, a_max=max_value)\n    return jnp.exp(x)\n
"},{"location":"reference/integration/#jaxley.solver_gate.solve_inf_gate_exponential","title":"solve_inf_gate_exponential(x, dt, s_inf, tau_s)","text":"

solves dx/dt = (s_inf - x) / tau_s via exponential Euler

Parameters:

Name Type Description Default x ndarray

gate variable

required dt float

time_delta

required s_inf ndarray

description

required tau_s ndarray

description

required

Returns:

Name Type Description _type_

updated gate

Source code in jaxley/solver_gate.py
def solve_inf_gate_exponential(\n    x: jnp.ndarray,\n    dt: float,\n    s_inf: jnp.ndarray,\n    tau_s: jnp.ndarray,\n):\n    \"\"\"solves dx/dt = (s_inf - x) / tau_s\n    via exponential Euler\n\n    Args:\n        x (jnp.ndarray): gate variable\n        dt (float): time_delta\n        s_inf (jnp.ndarray): _description_\n        tau_s (jnp.ndarray): _description_\n\n    Returns:\n        _type_: updated gate\n    \"\"\"\n    slope = -1.0 / tau_s\n    exp_term = save_exp(slope * dt)\n    return x * exp_term + s_inf * (1.0 - exp_term)\n
"},{"location":"reference/integration/#jaxley.solver_voltage.step_voltage_explicit","title":"step_voltage_explicit(voltages, voltage_terms, constant_terms, axial_conductances, internal_node_inds, sinks, sources, types, ncomp_per_branch, par_inds, child_inds, nbranches, solver, delta_t, idx, debug_states)","text":"

Solve one timestep of branched nerve equations with explicit (forward) Euler.

Source code in jaxley/solver_voltage.py
def step_voltage_explicit(\n    voltages: jnp.ndarray,\n    voltage_terms: jnp.ndarray,\n    constant_terms: jnp.ndarray,\n    axial_conductances: jnp.ndarray,\n    internal_node_inds: jnp.ndarray,\n    sinks: jnp.ndarray,\n    sources: jnp.ndarray,\n    types: jnp.ndarray,\n    ncomp_per_branch: jnp.ndarray,\n    par_inds: jnp.ndarray,\n    child_inds: jnp.ndarray,\n    nbranches: int,\n    solver: str,\n    delta_t: float,\n    idx: JaxleySolveIndexer,\n    debug_states,\n) -> jnp.ndarray:\n    \"\"\"Solve one timestep of branched nerve equations with explicit (forward) Euler.\"\"\"\n    voltages = jnp.reshape(voltages, (nbranches, -1))\n    voltage_terms = jnp.reshape(voltage_terms, (nbranches, -1))\n    constant_terms = jnp.reshape(constant_terms, (nbranches, -1))\n\n    update = _voltage_vectorfield(\n        voltages,\n        voltage_terms,\n        constant_terms,\n        types,\n        sources,\n        sinks,\n        axial_conductances,\n        par_inds,\n        child_inds,\n        nbranches,\n        solver,\n        delta_t,\n        idx,\n        debug_states,\n    )\n    new_voltates = voltages + delta_t * update\n    return new_voltates.ravel(order=\"C\")\n
"},{"location":"reference/integration/#jaxley.solver_voltage.step_voltage_implicit_with_jaxley_spsolve","title":"step_voltage_implicit_with_jaxley_spsolve(voltages, voltage_terms, constant_terms, axial_conductances, internal_node_inds, sinks, sources, types, ncomp_per_branch, par_inds, child_inds, nbranches, solver, delta_t, idx, debug_states)","text":"

Solve one timestep of branched nerve equations with implicit (backward) Euler.

Source code in jaxley/solver_voltage.py
def step_voltage_implicit_with_jaxley_spsolve(\n    voltages: jnp.ndarray,\n    voltage_terms: jnp.ndarray,\n    constant_terms: jnp.ndarray,\n    axial_conductances: jnp.ndarray,\n    internal_node_inds: jnp.ndarray,\n    sinks: jnp.ndarray,\n    sources: jnp.ndarray,\n    types: jnp.ndarray,\n    ncomp_per_branch: jnp.ndarray,\n    par_inds: jnp.ndarray,\n    child_inds: jnp.ndarray,\n    nbranches: int,\n    solver: str,\n    delta_t: float,\n    idx: JaxleySolveIndexer,\n    debug_states,\n):\n    \"\"\"Solve one timestep of branched nerve equations with implicit (backward) Euler.\"\"\"\n    # Build diagonals.\n    c2c = np.isin(types, [0, 1, 2])\n    total_ncomp = idx.cumsum_ncomp[-1]\n    diags = jnp.ones(total_ncomp)\n\n    # if-case needed because `.at` does not allow empty inputs, but the input is\n    # empty for compartments.\n    if len(sinks[c2c]) > 0:\n        diags = diags.at[idx.mask(sinks[c2c])].add(delta_t * axial_conductances[c2c])\n\n    diags = diags.at[idx.mask(internal_node_inds)].add(delta_t * voltage_terms)\n\n    # Build solves.\n    solves = jnp.zeros(total_ncomp)\n    solves = solves.at[idx.mask(internal_node_inds)].add(\n        voltages + delta_t * constant_terms\n    )\n\n    # Build upper and lower within the branch.\n    c2c = types == 0  # c2c = compartment-to-compartment.\n\n    # Build uppers.\n    uppers = jnp.zeros(total_ncomp)\n    upper_inds = sources[c2c] > sinks[c2c]\n    sinks_upper = sinks[c2c][upper_inds]\n    if len(sinks_upper) > 0:\n        uppers = uppers.at[idx.mask(sinks_upper)].add(\n            -delta_t * axial_conductances[c2c][upper_inds]\n        )\n\n    # Build lowers.\n    lowers = jnp.zeros(total_ncomp)\n    lower_inds = sources[c2c] < sinks[c2c]\n    sinks_lower = sinks[c2c][lower_inds]\n    if len(sinks_lower) > 0:\n        lowers = lowers.at[idx.mask(sinks_lower)].add(\n            -delta_t * axial_conductances[c2c][lower_inds]\n        )\n\n    # Build branchpoint conductances.\n    branchpoint_conds_parents = axial_conductances[types == 1]\n    branchpoint_conds_children = axial_conductances[types == 2]\n    branchpoint_weights_parents = axial_conductances[types == 3]\n    branchpoint_weights_children = axial_conductances[types == 4]\n    all_branchpoint_vals = jnp.concatenate(\n        [branchpoint_weights_parents, branchpoint_weights_children]\n    )\n    # Find unique group identifiers\n    num_branchpoints = len(branchpoint_conds_parents)\n    branchpoint_diags = -group_and_sum(\n        all_branchpoint_vals, idx.branchpoint_group_inds, num_branchpoints\n    )\n    branchpoint_solves = jnp.zeros((num_branchpoints,))\n\n    branchpoint_conds_children = -delta_t * branchpoint_conds_children\n    branchpoint_conds_parents = -delta_t * branchpoint_conds_parents\n\n    # Here, I move all child and parent indices towards a branchpoint into a larger\n    # vector. This is wasteful, but it makes indexing much easier. JIT compiling\n    # makes the speed difference negligible.\n    # Children.\n    bp_conds_children = jnp.zeros(nbranches)\n    bp_weights_children = jnp.zeros(nbranches)\n    # Parents.\n    bp_conds_parents = jnp.zeros(nbranches)\n    bp_weights_parents = jnp.zeros(nbranches)\n\n    # `.at[inds]` requires that `inds` is not empty, so we need an if-case here.\n    # `len(inds) == 0` is the case for branches and compartments.\n    if num_branchpoints > 0:\n        bp_conds_children = bp_conds_children.at[child_inds].set(\n            branchpoint_conds_children\n        )\n        bp_weights_children = bp_weights_children.at[child_inds].set(\n            branchpoint_weights_children\n        )\n        bp_conds_parents = bp_conds_parents.at[par_inds].set(branchpoint_conds_parents)\n        bp_weights_parents = bp_weights_parents.at[par_inds].set(\n            branchpoint_weights_parents\n        )\n\n    # Triangulate the linear system of equations.\n    (\n        diags,\n        lowers,\n        solves,\n        uppers,\n        branchpoint_diags,\n        branchpoint_solves,\n        bp_weights_children,\n        bp_conds_parents,\n    ) = _triang_branched(\n        lowers,\n        diags,\n        uppers,\n        solves,\n        bp_conds_children,\n        bp_conds_parents,\n        bp_weights_children,\n        bp_weights_parents,\n        branchpoint_diags,\n        branchpoint_solves,\n        solver,\n        ncomp_per_branch,\n        idx,\n        debug_states,\n    )\n\n    # Backsubstitute the linear system of equations.\n    (\n        solves,\n        lowers,\n        diags,\n        bp_weights_parents,\n        branchpoint_solves,\n        bp_conds_children,\n    ) = _backsub_branched(\n        lowers,\n        diags,\n        uppers,\n        solves,\n        bp_conds_children,\n        bp_conds_parents,\n        bp_weights_children,\n        bp_weights_parents,\n        branchpoint_diags,\n        branchpoint_solves,\n        solver,\n        ncomp_per_branch,\n        idx,\n        debug_states,\n    )\n    return solves.ravel(order=\"C\")[idx.mask(internal_node_inds)]\n
"},{"location":"reference/mechanisms/","title":"Channels","text":""},{"location":"reference/mechanisms/#channel","title":"Channel","text":"

Channel base class. All channels inherit from this class.

As in NEURON, a Channel is considered a distributed process, which means that its conductances are to be specified in S/cm2 and its currents are to be specified in uA/cm2.

Source code in jaxley/channels/channel.py
class Channel:\n    \"\"\"Channel base class. All channels inherit from this class.\n\n    As in NEURON, a `Channel` is considered a distributed process, which means that its\n    conductances are to be specified in `S/cm2` and its currents are to be specified in\n    `uA/cm2`.\"\"\"\n\n    _name = None\n    channel_params = None\n    channel_states = None\n    current_name = None\n\n    def __init__(self, name: Optional[str] = None):\n        contact = (\n            \"If you have any questions, please reach out via email to \"\n            \"michael.deistler@uni-tuebingen.de or create an issue on Github: \"\n            \"https://github.com/jaxleyverse/jaxley/issues. Thank you!\"\n        )\n        if (\n            not hasattr(self, \"current_is_in_mA_per_cm2\")\n            or not self.current_is_in_mA_per_cm2\n        ):\n            raise ValueError(\n                \"The channel you are using is deprecated. \"\n                \"In Jaxley version 0.5.0, we changed the unit of the current returned \"\n                \"by `compute_current` of channels from `uA/cm^2` to `mA/cm^2`. Please \"\n                \"update your channel model (by dividing the resulting current by 1000) \"\n                \"and set `self.current_is_in_mA_per_cm2=True` as the first line \"\n                f\"in the `__init__()` method of your channel. {contact}\"\n            )\n\n        self._name = name if name else self.__class__.__name__\n\n    @property\n    def name(self) -> Optional[str]:\n        \"\"\"The name of the channel (by default, this is the class name).\"\"\"\n        return self._name\n\n    def change_name(self, new_name: str):\n        \"\"\"Change the channel name.\n\n        Args:\n            new_name: The new name of the channel.\n\n        Returns:\n            Renamed channel, such that this function is chainable.\n        \"\"\"\n        old_prefix = self._name + \"_\"\n        new_prefix = new_name + \"_\"\n\n        self._name = new_name\n        self.channel_params = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.channel_params.items()\n        }\n\n        self.channel_states = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.channel_states.items()\n        }\n        return self\n\n    def update_states(\n        self, states, dt, v, params\n    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"Return the updated states.\"\"\"\n        raise NotImplementedError\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Given channel states and voltage, return the current through the channel.\n\n        Args:\n            states: All states of the compartment.\n            v: Voltage of the compartment in mV.\n            params: Parameters of the channel (conductances in `S/cm2`).\n\n        Returns:\n            Current in `uA/cm2`.\n        \"\"\"\n        raise NotImplementedError\n\n    def init_state(\n        self,\n        states: Dict[str, jnp.ndarray],\n        v: jnp.ndarray,\n        params: Dict[str, jnp.ndarray],\n        delta_t: float,\n    ):\n        \"\"\"Initialize states of channel.\"\"\"\n        return {}\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.name","title":"name: Optional[str] property","text":"

The name of the channel (by default, this is the class name).

"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.change_name","title":"change_name(new_name)","text":"

Change the channel name.

Parameters:

Name Type Description Default new_name str

The new name of the channel.

required

Returns:

Type Description

Renamed channel, such that this function is chainable.

Source code in jaxley/channels/channel.py
def change_name(self, new_name: str):\n    \"\"\"Change the channel name.\n\n    Args:\n        new_name: The new name of the channel.\n\n    Returns:\n        Renamed channel, such that this function is chainable.\n    \"\"\"\n    old_prefix = self._name + \"_\"\n    new_prefix = new_name + \"_\"\n\n    self._name = new_name\n    self.channel_params = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.channel_params.items()\n    }\n\n    self.channel_states = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.channel_states.items()\n    }\n    return self\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.compute_current","title":"compute_current(states, v, params)","text":"

Given channel states and voltage, return the current through the channel.

Parameters:

Name Type Description Default states Dict[str, ndarray]

All states of the compartment.

required v

Voltage of the compartment in mV.

required params Dict[str, ndarray]

Parameters of the channel (conductances in S/cm2).

required

Returns:

Type Description

Current in uA/cm2.

Source code in jaxley/channels/channel.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Given channel states and voltage, return the current through the channel.\n\n    Args:\n        states: All states of the compartment.\n        v: Voltage of the compartment in mV.\n        params: Parameters of the channel (conductances in `S/cm2`).\n\n    Returns:\n        Current in `uA/cm2`.\n    \"\"\"\n    raise NotImplementedError\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.init_state","title":"init_state(states, v, params, delta_t)","text":"

Initialize states of channel.

Source code in jaxley/channels/channel.py
def init_state(\n    self,\n    states: Dict[str, jnp.ndarray],\n    v: jnp.ndarray,\n    params: Dict[str, jnp.ndarray],\n    delta_t: float,\n):\n    \"\"\"Initialize states of channel.\"\"\"\n    return {}\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.update_states","title":"update_states(states, dt, v, params)","text":"

Return the updated states.

Source code in jaxley/channels/channel.py
def update_states(\n    self, states, dt, v, params\n) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:\n    \"\"\"Return the updated states.\"\"\"\n    raise NotImplementedError\n
"},{"location":"reference/mechanisms/#hh","title":"HH","text":"

Bases: Channel

Hodgkin-Huxley channel.

Source code in jaxley/channels/hh.py
class HH(Channel):\n    \"\"\"Hodgkin-Huxley channel.\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        self.current_is_in_mA_per_cm2 = True\n\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gNa\": 0.12,\n            f\"{prefix}_gK\": 0.036,\n            f\"{prefix}_gLeak\": 0.0003,\n            f\"{prefix}_eNa\": 50.0,\n            f\"{prefix}_eK\": -77.0,\n            f\"{prefix}_eLeak\": -54.3,\n        }\n        self.channel_states = {\n            f\"{prefix}_m\": 0.2,\n            f\"{prefix}_h\": 0.2,\n            f\"{prefix}_n\": 0.2,\n        }\n        self.current_name = f\"i_HH\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Return updated HH channel state.\"\"\"\n        prefix = self._name\n        m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n        new_m = solve_gate_exponential(m, dt, *self.m_gate(v))\n        new_h = solve_gate_exponential(h, dt, *self.h_gate(v))\n        new_n = solve_gate_exponential(n, dt, *self.n_gate(v))\n        return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h, f\"{prefix}_n\": new_n}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current through HH channels.\"\"\"\n        prefix = self._name\n        m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n\n        gNa = params[f\"{prefix}_gNa\"] * (m**3) * h  # S/cm^2\n        gK = params[f\"{prefix}_gK\"] * n**4  # S/cm^2\n        gLeak = params[f\"{prefix}_gLeak\"]  # S/cm^2\n\n        return (\n            gNa * (v - params[f\"{prefix}_eNa\"])\n            + gK * (v - params[f\"{prefix}_eK\"])\n            + gLeak * (v - params[f\"{prefix}_eLeak\"])\n        )\n\n    def init_state(self, states, v, params, delta_t):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_m, beta_m = self.m_gate(v)\n        alpha_h, beta_h = self.h_gate(v)\n        alpha_n, beta_n = self.n_gate(v)\n        return {\n            f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n            f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n            f\"{prefix}_n\": alpha_n / (alpha_n + beta_n),\n        }\n\n    @staticmethod\n    def m_gate(v):\n        alpha = 0.1 * _vtrap(-(v + 40), 10)\n        beta = 4.0 * save_exp(-(v + 65) / 18)\n        return alpha, beta\n\n    @staticmethod\n    def h_gate(v):\n        alpha = 0.07 * save_exp(-(v + 65) / 20)\n        beta = 1.0 / (save_exp(-(v + 35) / 10) + 1)\n        return alpha, beta\n\n    @staticmethod\n    def n_gate(v):\n        alpha = 0.01 * _vtrap(-(v + 55), 10)\n        beta = 0.125 * save_exp(-(v + 65) / 80)\n        return alpha, beta\n
"},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.compute_current","title":"compute_current(states, v, params)","text":"

Return current through HH channels.

Source code in jaxley/channels/hh.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current through HH channels.\"\"\"\n    prefix = self._name\n    m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n\n    gNa = params[f\"{prefix}_gNa\"] * (m**3) * h  # S/cm^2\n    gK = params[f\"{prefix}_gK\"] * n**4  # S/cm^2\n    gLeak = params[f\"{prefix}_gLeak\"]  # S/cm^2\n\n    return (\n        gNa * (v - params[f\"{prefix}_eNa\"])\n        + gK * (v - params[f\"{prefix}_eK\"])\n        + gLeak * (v - params[f\"{prefix}_eLeak\"])\n    )\n
"},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.init_state","title":"init_state(states, v, params, delta_t)","text":"

Initialize the state such at fixed point of gate dynamics.

Source code in jaxley/channels/hh.py
def init_state(self, states, v, params, delta_t):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_m, beta_m = self.m_gate(v)\n    alpha_h, beta_h = self.h_gate(v)\n    alpha_n, beta_n = self.n_gate(v)\n    return {\n        f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n        f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n        f\"{prefix}_n\": alpha_n / (alpha_n + beta_n),\n    }\n
"},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.update_states","title":"update_states(states, dt, v, params)","text":"

Return updated HH channel state.

Source code in jaxley/channels/hh.py
def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Return updated HH channel state.\"\"\"\n    prefix = self._name\n    m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n    new_m = solve_gate_exponential(m, dt, *self.m_gate(v))\n    new_h = solve_gate_exponential(h, dt, *self.h_gate(v))\n    new_n = solve_gate_exponential(n, dt, *self.n_gate(v))\n    return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h, f\"{prefix}_n\": new_n}\n
"},{"location":"reference/mechanisms/#pospischil","title":"Pospischil","text":"

Bases: Channel

Leak current

Source code in jaxley/channels/pospischil.py
class Leak(Channel):\n    \"\"\"Leak current\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        self.current_is_in_mA_per_cm2 = True\n\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gLeak\": 1e-4,\n            f\"{prefix}_eLeak\": -70.0,\n        }\n        self.channel_states = {}\n        self.current_name = f\"i_{prefix}\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"No state to update.\"\"\"\n        return {}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        gLeak = params[f\"{prefix}_gLeak\"]  # S/cm^2\n        return gLeak * (v - params[f\"{prefix}_eLeak\"])\n\n    def init_state(self, states, v, params, delta_t):\n        return {}\n

Bases: Channel

Sodium channel

Source code in jaxley/channels/pospischil.py
class Na(Channel):\n    \"\"\"Sodium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        self.current_is_in_mA_per_cm2 = True\n\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gNa\": 50e-3,\n            \"eNa\": 50.0,\n            \"vt\": -60.0,  # Global parameter, not prefixed with `Na`.\n        }\n        self.channel_states = {f\"{prefix}_m\": 0.2, f\"{prefix}_h\": 0.2}\n        self.current_name = f\"i_Na\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n        new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params[\"vt\"]))\n        new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params[\"vt\"]))\n        return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n\n        gNa = params[f\"{prefix}_gNa\"] * (m**3) * h  # S/cm^2\n\n        current = gNa * (v - params[\"eNa\"])\n        return current\n\n    def init_state(self, states, v, params, delta_t):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_m, beta_m = self.m_gate(v, params[\"vt\"])\n        alpha_h, beta_h = self.h_gate(v, params[\"vt\"])\n        return {\n            f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n            f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n        }\n\n    @staticmethod\n    def m_gate(v, vt):\n        v_alpha = v - vt - 13.0\n        alpha = 0.32 * efun(-0.25 * v_alpha) / 0.25\n\n        v_beta = v - vt - 40.0\n        beta = 0.28 * efun(0.2 * v_beta) / 0.2\n        return alpha, beta\n\n    @staticmethod\n    def h_gate(v, vt):\n        v_alpha = v - vt - 17.0\n        alpha = 0.128 * save_exp(-v_alpha / 18.0)\n\n        v_beta = v - vt - 40.0\n        beta = 4.0 / (save_exp(-v_beta / 5.0) + 1.0)\n        return alpha, beta\n

Bases: Channel

Potassium channel

Source code in jaxley/channels/pospischil.py
class K(Channel):\n    \"\"\"Potassium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        self.current_is_in_mA_per_cm2 = True\n\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gK\": 5e-3,\n            \"eK\": -90.0,\n            \"vt\": -60.0,  # Global parameter, not prefixed with `Na`.\n        }\n        self.channel_states = {f\"{prefix}_n\": 0.2}\n        self.current_name = f\"i_K\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        n = states[f\"{prefix}_n\"]\n        new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params[\"vt\"]))\n        return {f\"{prefix}_n\": new_n}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        n = states[f\"{prefix}_n\"]\n\n        gK = params[f\"{prefix}_gK\"] * (n**4)  # S/cm^2\n\n        return gK * (v - params[\"eK\"])\n\n    def init_state(self, states, v, params, delta_t):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_n, beta_n = self.n_gate(v, params[\"vt\"])\n        return {f\"{prefix}_n\": alpha_n / (alpha_n + beta_n)}\n\n    @staticmethod\n    def n_gate(v, vt):\n        v_alpha = v - vt - 15.0\n        alpha = 0.032 * efun(-0.2 * v_alpha) / 0.2\n\n        v_beta = v - vt - 10.0\n        beta = 0.5 * save_exp(-v_beta / 40.0)\n        return alpha, beta\n

Bases: Channel

Slow M Potassium channel

Source code in jaxley/channels/pospischil.py
class Km(Channel):\n    \"\"\"Slow M Potassium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        self.current_is_in_mA_per_cm2 = True\n\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gKm\": 0.004e-3,\n            f\"{prefix}_taumax\": 4000.0,\n            f\"eK\": -90.0,\n        }\n        self.channel_states = {f\"{prefix}_p\": 0.2}\n        self.current_name = f\"i_K\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        p = states[f\"{prefix}_p\"]\n        new_p = solve_inf_gate_exponential(\n            p, dt, *self.p_gate(v, params[f\"{prefix}_taumax\"])\n        )\n        return {f\"{prefix}_p\": new_p}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        p = states[f\"{prefix}_p\"]\n\n        gKm = params[f\"{prefix}_gKm\"] * p  # S/cm^2\n        return gKm * (v - params[\"eK\"])\n\n    def init_state(self, states, v, params, delta_t):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_p, beta_p = self.p_gate(v, params[f\"{prefix}_taumax\"])\n        return {f\"{prefix}_p\": alpha_p / (alpha_p + beta_p)}\n\n    @staticmethod\n    def p_gate(v, taumax):\n        v_p = v + 35.0\n        p_inf = 1.0 / (1.0 + save_exp(-0.1 * v_p))\n\n        tau_p = taumax / (3.3 * save_exp(0.05 * v_p) + save_exp(-0.05 * v_p))\n\n        return p_inf, tau_p\n

Bases: Channel

L-type Calcium channel

Source code in jaxley/channels/pospischil.py
class CaL(Channel):\n    \"\"\"L-type Calcium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        self.current_is_in_mA_per_cm2 = True\n\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gCaL\": 0.1e-3,\n            \"eCa\": 120.0,\n        }\n        self.channel_states = {f\"{prefix}_q\": 0.2, f\"{prefix}_r\": 0.2}\n        self.current_name = f\"i_Ca\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n        new_q = solve_gate_exponential(q, dt, *self.q_gate(v))\n        new_r = solve_gate_exponential(r, dt, *self.r_gate(v))\n        return {f\"{prefix}_q\": new_q, f\"{prefix}_r\": new_r}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n        gCaL = params[f\"{prefix}_gCaL\"] * (q**2) * r  # S/cm^2\n\n        return gCaL * (v - params[\"eCa\"])\n\n    def init_state(self, states, v, params, delta_t):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_q, beta_q = self.q_gate(v)\n        alpha_r, beta_r = self.r_gate(v)\n        return {\n            f\"{prefix}_q\": alpha_q / (alpha_q + beta_q),\n            f\"{prefix}_r\": alpha_r / (alpha_r + beta_r),\n        }\n\n    @staticmethod\n    def q_gate(v):\n        v_alpha = -v - 27.0\n        alpha = 0.055 * efun(v_alpha / 3.8) * 3.8\n\n        v_beta = -v - 75.0\n        beta = 0.94 * save_exp(v_beta / 17.0)\n        return alpha, beta\n\n    @staticmethod\n    def r_gate(v):\n        v_alpha = -v - 13.0\n        alpha = 0.000457 * save_exp(v_alpha / 50)\n\n        v_beta = -v - 15.0\n        beta = 0.0065 / (save_exp(v_beta / 28.0) + 1)\n        return alpha, beta\n

Bases: Channel

T-type Calcium channel

Source code in jaxley/channels/pospischil.py
class CaT(Channel):\n    \"\"\"T-type Calcium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        self.current_is_in_mA_per_cm2 = True\n\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gCaT\": 0.4e-4,\n            f\"{prefix}_vx\": 2.0,\n            \"eCa\": 120.0,  # Global parameter, not prefixed with `CaT`.\n        }\n        self.channel_states = {f\"{prefix}_u\": 0.2}\n        self.current_name = f\"i_Ca\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        u = states[f\"{prefix}_u\"]\n        new_u = solve_inf_gate_exponential(\n            u, dt, *self.u_gate(v, params[f\"{prefix}_vx\"])\n        )\n        return {f\"{prefix}_u\": new_u}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        u = states[f\"{prefix}_u\"]\n        s_inf = 1.0 / (1.0 + save_exp(-(v + params[f\"{prefix}_vx\"] + 57.0) / 6.2))\n\n        gCaT = params[f\"{prefix}_gCaT\"] * (s_inf**2) * u  # S/cm^2\n\n        return gCaT * (v - params[\"eCa\"])\n\n    def init_state(self, states, v, params, delta_t):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_u, beta_u = self.u_gate(v, params[f\"{prefix}_vx\"])\n        return {f\"{prefix}_u\": alpha_u / (alpha_u + beta_u)}\n\n    @staticmethod\n    def u_gate(v, vx):\n        v_u1 = v + vx + 81.0\n        u_inf = 1.0 / (1.0 + save_exp(v_u1 / 4))\n\n        tau_u = (30.8 + (211.4 + save_exp((v + vx + 113.2) / 5.0))) / (\n            3.7 * (1 + save_exp((v + vx + 84.0) / 3.2))\n        )\n\n        return u_inf, tau_u\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Leak.compute_current","title":"compute_current(states, v, params)","text":"

Return current.

Source code in jaxley/channels/pospischil.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    gLeak = params[f\"{prefix}_gLeak\"]  # S/cm^2\n    return gLeak * (v - params[f\"{prefix}_eLeak\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Leak.update_states","title":"update_states(states, dt, v, params)","text":"

No state to update.

Source code in jaxley/channels/pospischil.py
def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"No state to update.\"\"\"\n    return {}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.compute_current","title":"compute_current(states, v, params)","text":"

Return current.

Source code in jaxley/channels/pospischil.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n\n    gNa = params[f\"{prefix}_gNa\"] * (m**3) * h  # S/cm^2\n\n    current = gNa * (v - params[\"eNa\"])\n    return current\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.init_state","title":"init_state(states, v, params, delta_t)","text":"

Initialize the state such at fixed point of gate dynamics.

Source code in jaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_m, beta_m = self.m_gate(v, params[\"vt\"])\n    alpha_h, beta_h = self.h_gate(v, params[\"vt\"])\n    return {\n        f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n        f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n    }\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.update_states","title":"update_states(states, dt, v, params)","text":"

Update state.

Source code in jaxley/channels/pospischil.py
def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n    new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params[\"vt\"]))\n    new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params[\"vt\"]))\n    return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.compute_current","title":"compute_current(states, v, params)","text":"

Return current.

Source code in jaxley/channels/pospischil.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    n = states[f\"{prefix}_n\"]\n\n    gK = params[f\"{prefix}_gK\"] * (n**4)  # S/cm^2\n\n    return gK * (v - params[\"eK\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.init_state","title":"init_state(states, v, params, delta_t)","text":"

Initialize the state such at fixed point of gate dynamics.

Source code in jaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_n, beta_n = self.n_gate(v, params[\"vt\"])\n    return {f\"{prefix}_n\": alpha_n / (alpha_n + beta_n)}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.update_states","title":"update_states(states, dt, v, params)","text":"

Update state.

Source code in jaxley/channels/pospischil.py
def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    n = states[f\"{prefix}_n\"]\n    new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params[\"vt\"]))\n    return {f\"{prefix}_n\": new_n}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.compute_current","title":"compute_current(states, v, params)","text":"

Return current.

Source code in jaxley/channels/pospischil.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    p = states[f\"{prefix}_p\"]\n\n    gKm = params[f\"{prefix}_gKm\"] * p  # S/cm^2\n    return gKm * (v - params[\"eK\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.init_state","title":"init_state(states, v, params, delta_t)","text":"

Initialize the state such at fixed point of gate dynamics.

Source code in jaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_p, beta_p = self.p_gate(v, params[f\"{prefix}_taumax\"])\n    return {f\"{prefix}_p\": alpha_p / (alpha_p + beta_p)}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.update_states","title":"update_states(states, dt, v, params)","text":"

Update state.

Source code in jaxley/channels/pospischil.py
def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    p = states[f\"{prefix}_p\"]\n    new_p = solve_inf_gate_exponential(\n        p, dt, *self.p_gate(v, params[f\"{prefix}_taumax\"])\n    )\n    return {f\"{prefix}_p\": new_p}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.compute_current","title":"compute_current(states, v, params)","text":"

Return current.

Source code in jaxley/channels/pospischil.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n    gCaL = params[f\"{prefix}_gCaL\"] * (q**2) * r  # S/cm^2\n\n    return gCaL * (v - params[\"eCa\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.init_state","title":"init_state(states, v, params, delta_t)","text":"

Initialize the state such at fixed point of gate dynamics.

Source code in jaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_q, beta_q = self.q_gate(v)\n    alpha_r, beta_r = self.r_gate(v)\n    return {\n        f\"{prefix}_q\": alpha_q / (alpha_q + beta_q),\n        f\"{prefix}_r\": alpha_r / (alpha_r + beta_r),\n    }\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.update_states","title":"update_states(states, dt, v, params)","text":"

Update state.

Source code in jaxley/channels/pospischil.py
def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n    new_q = solve_gate_exponential(q, dt, *self.q_gate(v))\n    new_r = solve_gate_exponential(r, dt, *self.r_gate(v))\n    return {f\"{prefix}_q\": new_q, f\"{prefix}_r\": new_r}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.compute_current","title":"compute_current(states, v, params)","text":"

Return current.

Source code in jaxley/channels/pospischil.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    u = states[f\"{prefix}_u\"]\n    s_inf = 1.0 / (1.0 + save_exp(-(v + params[f\"{prefix}_vx\"] + 57.0) / 6.2))\n\n    gCaT = params[f\"{prefix}_gCaT\"] * (s_inf**2) * u  # S/cm^2\n\n    return gCaT * (v - params[\"eCa\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.init_state","title":"init_state(states, v, params, delta_t)","text":"

Initialize the state such at fixed point of gate dynamics.

Source code in jaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_u, beta_u = self.u_gate(v, params[f\"{prefix}_vx\"])\n    return {f\"{prefix}_u\": alpha_u / (alpha_u + beta_u)}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.update_states","title":"update_states(states, dt, v, params)","text":"

Update state.

Source code in jaxley/channels/pospischil.py
def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    u = states[f\"{prefix}_u\"]\n    new_u = solve_inf_gate_exponential(\n        u, dt, *self.u_gate(v, params[f\"{prefix}_vx\"])\n    )\n    return {f\"{prefix}_u\": new_u}\n
"},{"location":"reference/mechanisms/#synapses","title":"Synapses","text":""},{"location":"reference/mechanisms/#synapse","title":"Synapse","text":"

Base class for a synapse.

As in NEURON, a Synapse is considered a point process, which means that its conductances are to be specified in uS and its currents are to be specified in nA.

Source code in jaxley/synapses/synapse.py
class Synapse:\n    \"\"\"Base class for a synapse.\n\n    As in NEURON, a `Synapse` is considered a point process, which means that its\n    conductances are to be specified in `uS` and its currents are to be specified in\n    `nA`.\n    \"\"\"\n\n    _name = None\n    synapse_params = None\n    synapse_states = None\n\n    def __init__(self, name: Optional[str] = None):\n        self._name = name if name else self.__class__.__name__\n\n    @property\n    def name(self) -> Optional[str]:\n        return self._name\n\n    def change_name(self, new_name: str):\n        \"\"\"Change the synapse name.\n\n        Args:\n            new_name: The new name of the channel.\n\n        Returns:\n            Renamed channel, such that this function is chainable.\n        \"\"\"\n        old_prefix = self._name + \"_\"\n        new_prefix = new_name + \"_\"\n\n        self._name = new_name\n        self.synapse_params = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.synapse_params.items()\n        }\n\n        self.synapse_states = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.synapse_states.items()\n        }\n        return self\n\n    def update_states(\n        states: Dict[str, jnp.ndarray],\n        delta_t: float,\n        pre_voltage: jnp.ndarray,\n        post_voltage: jnp.ndarray,\n        params: Dict[str, jnp.ndarray],\n    ) -> Dict[str, jnp.ndarray]:\n        \"\"\"ODE update step.\n\n        Args:\n            states: States of the synapse.\n            delta_t: Time step in `ms`.\n            pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n            post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n            params: Parameters of the synapse. Conductances in `uS`.\n\n        Returns:\n            Updated states.\"\"\"\n        raise NotImplementedError\n\n    def compute_current(\n        states: Dict[str, jnp.ndarray],\n        pre_voltage: jnp.ndarray,\n        post_voltage: jnp.ndarray,\n        params: Dict[str, jnp.ndarray],\n    ) -> jnp.ndarray:\n        \"\"\"Return current through one synapse in `nA`.\n\n        Internally, we use `jax.vmap` to vectorize this function across many synapses.\n\n        Args:\n            states: States of the synapse.\n            pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n            post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n            params: Parameters of the synapse. Conductances in `uS`.\n\n        Returns:\n            Current through the synapse in `nA`, shape `()`.\n        \"\"\"\n        raise NotImplementedError\n
"},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.change_name","title":"change_name(new_name)","text":"

Change the synapse name.

Parameters:

Name Type Description Default new_name str

The new name of the channel.

required

Returns:

Type Description

Renamed channel, such that this function is chainable.

Source code in jaxley/synapses/synapse.py
def change_name(self, new_name: str):\n    \"\"\"Change the synapse name.\n\n    Args:\n        new_name: The new name of the channel.\n\n    Returns:\n        Renamed channel, such that this function is chainable.\n    \"\"\"\n    old_prefix = self._name + \"_\"\n    new_prefix = new_name + \"_\"\n\n    self._name = new_name\n    self.synapse_params = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.synapse_params.items()\n    }\n\n    self.synapse_states = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.synapse_states.items()\n    }\n    return self\n
"},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.compute_current","title":"compute_current(states, pre_voltage, post_voltage, params)","text":"

Return current through one synapse in nA.

Internally, we use jax.vmap to vectorize this function across many synapses.

Parameters:

Name Type Description Default states Dict[str, ndarray]

States of the synapse.

required pre_voltage ndarray

Voltage of the presynaptic compartment, shape ().

required post_voltage ndarray

Voltage of the postsynaptic compartment, shape ().

required params Dict[str, ndarray]

Parameters of the synapse. Conductances in uS.

required

Returns:

Type Description ndarray

Current through the synapse in nA, shape ().

Source code in jaxley/synapses/synapse.py
def compute_current(\n    states: Dict[str, jnp.ndarray],\n    pre_voltage: jnp.ndarray,\n    post_voltage: jnp.ndarray,\n    params: Dict[str, jnp.ndarray],\n) -> jnp.ndarray:\n    \"\"\"Return current through one synapse in `nA`.\n\n    Internally, we use `jax.vmap` to vectorize this function across many synapses.\n\n    Args:\n        states: States of the synapse.\n        pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n        post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n        params: Parameters of the synapse. Conductances in `uS`.\n\n    Returns:\n        Current through the synapse in `nA`, shape `()`.\n    \"\"\"\n    raise NotImplementedError\n
"},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)","text":"

ODE update step.

Parameters:

Name Type Description Default states Dict[str, ndarray]

States of the synapse.

required delta_t float

Time step in ms.

required pre_voltage ndarray

Voltage of the presynaptic compartment, shape ().

required post_voltage ndarray

Voltage of the postsynaptic compartment, shape ().

required params Dict[str, ndarray]

Parameters of the synapse. Conductances in uS.

required

Returns:

Type Description Dict[str, ndarray]

Updated states.

Source code in jaxley/synapses/synapse.py
def update_states(\n    states: Dict[str, jnp.ndarray],\n    delta_t: float,\n    pre_voltage: jnp.ndarray,\n    post_voltage: jnp.ndarray,\n    params: Dict[str, jnp.ndarray],\n) -> Dict[str, jnp.ndarray]:\n    \"\"\"ODE update step.\n\n    Args:\n        states: States of the synapse.\n        delta_t: Time step in `ms`.\n        pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n        post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n        params: Parameters of the synapse. Conductances in `uS`.\n\n    Returns:\n        Updated states.\"\"\"\n    raise NotImplementedError\n
"},{"location":"reference/mechanisms/#ionotropic-synapse","title":"Ionotropic Synapse","text":"

Bases: Synapse

Compute synaptic current and update synapse state for a generic ionotropic synapse.

The synapse state \u201cs\u201d is the probability that a postsynaptic receptor channel is open, and this depends on the amount of neurotransmitter released, which is in turn dependent on the presynaptic voltage.

The synaptic parameters are
  • gS: the maximal conductance across the postsynaptic membrane (uS)
  • e_syn: the reversal potential across the postsynaptic membrane (mV)
  • k_minus: the rate constant of neurotransmitter unbinding from the postsynaptic receptor (s^-1)
Details of this implementation can be found in the following book chapter

L. F. Abbott and E. Marder, \u201cModeling Small Networks,\u201d in Methods in Neuronal Modeling, C. Koch and I. Sergev, Eds. Cambridge: MIT Press, 1998.

Source code in jaxley/synapses/ionotropic.py
class IonotropicSynapse(Synapse):\n    \"\"\"\n    Compute synaptic current and update synapse state for a generic ionotropic synapse.\n\n    The synapse state \"s\" is the probability that a postsynaptic receptor channel is\n    open, and this depends on the amount of neurotransmitter released, which is in turn\n    dependent on the presynaptic voltage.\n\n    The synaptic parameters are:\n        - gS: the maximal conductance across the postsynaptic membrane (uS)\n        - e_syn: the reversal potential across the postsynaptic membrane (mV)\n        - k_minus: the rate constant of neurotransmitter unbinding from the postsynaptic\n            receptor (s^-1)\n\n    Details of this implementation can be found in the following book chapter:\n        L. F. Abbott and E. Marder, \"Modeling Small Networks,\" in Methods in Neuronal\n        Modeling, C. Koch and I. Sergev, Eds. Cambridge: MIT Press, 1998.\n\n    \"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.synapse_params = {\n            f\"{prefix}_gS\": 1e-4,\n            f\"{prefix}_e_syn\": 0.0,\n            f\"{prefix}_k_minus\": 0.025,\n        }\n        self.synapse_states = {f\"{prefix}_s\": 0.2}\n\n    def update_states(\n        self,\n        states: Dict,\n        delta_t: float,\n        pre_voltage: float,\n        post_voltage: float,\n        params: Dict,\n    ) -> Dict:\n        \"\"\"Return updated synapse state and current.\"\"\"\n        prefix = self._name\n        v_th = -35.0  # mV\n        delta = 10.0  # mV\n\n        s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta))\n        tau_s = (1.0 - s_inf) / params[f\"{prefix}_k_minus\"]\n\n        slope = -1.0 / tau_s\n        exp_term = save_exp(slope * delta_t)\n        new_s = states[f\"{prefix}_s\"] * exp_term + s_inf * (1.0 - exp_term)\n        return {f\"{prefix}_s\": new_s}\n\n    def compute_current(\n        self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n    ) -> float:\n        prefix = self._name\n        g_syn = params[f\"{prefix}_gS\"] * states[f\"{prefix}_s\"]\n        return g_syn * (post_voltage - params[f\"{prefix}_e_syn\"])\n
"},{"location":"reference/mechanisms/#jaxley.synapses.ionotropic.IonotropicSynapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)","text":"

Return updated synapse state and current.

Source code in jaxley/synapses/ionotropic.py
def update_states(\n    self,\n    states: Dict,\n    delta_t: float,\n    pre_voltage: float,\n    post_voltage: float,\n    params: Dict,\n) -> Dict:\n    \"\"\"Return updated synapse state and current.\"\"\"\n    prefix = self._name\n    v_th = -35.0  # mV\n    delta = 10.0  # mV\n\n    s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta))\n    tau_s = (1.0 - s_inf) / params[f\"{prefix}_k_minus\"]\n\n    slope = -1.0 / tau_s\n    exp_term = save_exp(slope * delta_t)\n    new_s = states[f\"{prefix}_s\"] * exp_term + s_inf * (1.0 - exp_term)\n    return {f\"{prefix}_s\": new_s}\n
"},{"location":"reference/mechanisms/#tanh-rate-synapse","title":"TanH Rate Synapse","text":"

Bases: Synapse

Compute synaptic current for tanh synapse (no state).

Source code in jaxley/synapses/tanh_rate.py
class TanhRateSynapse(Synapse):\n    \"\"\"\n    Compute synaptic current for tanh synapse (no state).\n    \"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.synapse_params = {\n            f\"{prefix}_gS\": 1e-4,\n            f\"{prefix}_x_offset\": -70.0,\n            f\"{prefix}_slope\": 1.0,\n        }\n        self.synapse_states = {}\n\n    def update_states(\n        self,\n        states: Dict,\n        delta_t: float,\n        pre_voltage: float,\n        post_voltage: float,\n        params: Dict,\n    ) -> Dict:\n        \"\"\"Return updated synapse state and current.\"\"\"\n        return {}\n\n    def compute_current(\n        self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n    ) -> float:\n        \"\"\"Return updated synapse state and current.\"\"\"\n        prefix = self._name\n        current = (\n            -1\n            * params[f\"{prefix}_gS\"]\n            * jnp.tanh(\n                (pre_voltage - params[f\"{prefix}_x_offset\"]) * params[f\"{prefix}_slope\"]\n            )\n        )\n        return current\n
"},{"location":"reference/mechanisms/#jaxley.synapses.tanh_rate.TanhRateSynapse.compute_current","title":"compute_current(states, pre_voltage, post_voltage, params)","text":"

Return updated synapse state and current.

Source code in jaxley/synapses/tanh_rate.py
def compute_current(\n    self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n) -> float:\n    \"\"\"Return updated synapse state and current.\"\"\"\n    prefix = self._name\n    current = (\n        -1\n        * params[f\"{prefix}_gS\"]\n        * jnp.tanh(\n            (pre_voltage - params[f\"{prefix}_x_offset\"]) * params[f\"{prefix}_slope\"]\n        )\n    )\n    return current\n
"},{"location":"reference/mechanisms/#jaxley.synapses.tanh_rate.TanhRateSynapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)","text":"

Return updated synapse state and current.

Source code in jaxley/synapses/tanh_rate.py
def update_states(\n    self,\n    states: Dict,\n    delta_t: float,\n    pre_voltage: float,\n    post_voltage: float,\n    params: Dict,\n) -> Dict:\n    \"\"\"Return updated synapse state and current.\"\"\"\n    return {}\n
"},{"location":"reference/modules/","title":"Modules","text":""},{"location":"reference/modules/#module","title":"Module","text":"

Bases: ABC

Module base class.

Modules are everything that can be passed to jx.integrate, i.e. compartments, branches, cells, and networks.

This base class defines the scaffold for all jaxley modules (compartments, branches, cells, networks).

Modules can be traversed and modified using the at, cell, branch, comp, edge, and loc methods. The scope method can be used to toggle between global and local indices. Traversal of Modules will return a View of itself, that has a modified set of attributes, which only consider the part of the Module that is in view.

For developers: The above has consequences for how to operate on Module and which changes take affect where. The following guidelines should be followed (copied from View):

  1. We consider a Module to have everything in view.
  2. Views can display and keep track of how a module is traversed. But(!), do not support making changes or setting variables. This still has to be done in the base Module, i.e. self.base. In order to enssure that these changes only affects whatever is currently in view self._nodes_in_view, or self._edges_in_view among others have to be used. Operating on nodes currently in view can for example be done with self.base.node.loc[self._nodes_in_view].
  3. Every attribute of Module that changes based on what\u2019s in view, i.e. xyzr, needs to modified when View is instantiated. I.e. xyzr of cell.branch(0), should be [self.base.xyzr[0]] This could be achieved via: [self.base.xyzr[b] for b in self._branches_in_view].

For developers: If you want to add a new method to Module, here is an example of how to make methods of Module compatible with View:

.. code-block:: python

# Use data in view to return something.\ndef count_small_branches(self):\n    # no need to use self.base.attr + viewed indices,\n    # since no change is made to the attr in question (nodes)\n    comp_lens = self.nodes[\"length\"]\n    branch_lens = comp_lens.groupby(\"global_branch_index\").sum()\n    return np.sum(branch_lens < 10)\n\n# Change data in view.\ndef change_attr_in_view(self):\n    # changes to attrs have to be made via self.base.attr + viewed indices\n    a = func1(self.base.attr1[self._cells_in_view])\n    b = func2(self.base.attr2[self._edges_in_view])\n    self.base.attr3[self._branches_in_view] = a + b\n
Source code in jaxley/modules/base.py
class Module(ABC):\n    \"\"\"Module base class.\n\n    Modules are everything that can be passed to `jx.integrate`, i.e. compartments,\n    branches, cells, and networks.\n\n    This base class defines the scaffold for all jaxley modules (compartments,\n    branches, cells, networks).\n\n    Modules can be traversed and modified using the `at`, `cell`, `branch`, `comp`,\n    `edge`, and `loc` methods. The `scope` method can be used to toggle between\n    global and local indices. Traversal of Modules will return a `View` of itself,\n    that has a modified set of attributes, which only consider the part of the Module\n    that is in view.\n\n    For developers: The above has consequences for how to operate on `Module` and which\n    changes take affect where. The following guidelines should be followed (copied from\n    `View`):\n\n    1. We consider a Module to have everything in view.\n    2. Views can display and keep track of how a module is traversed. But(!),\n       do not support making changes or setting variables. This still has to be\n       done in the base Module, i.e. `self.base`. In order to enssure that these\n       changes only affects whatever is currently in view `self._nodes_in_view`,\n       or `self._edges_in_view` among others have to be used. Operating on nodes\n       currently in view can for example be done with\n       `self.base.node.loc[self._nodes_in_view]`.\n    3. Every attribute of Module that changes based on what's in view, i.e. `xyzr`,\n       needs to modified when View is instantiated. I.e. `xyzr` of `cell.branch(0)`,\n       should be `[self.base.xyzr[0]]` This could be achieved via:\n       `[self.base.xyzr[b] for b in self._branches_in_view]`.\n\n    For developers: If you want to add a new method to `Module`, here is an example of\n    how to make methods of Module compatible with View:\n\n    .. code-block:: python\n\n        # Use data in view to return something.\n        def count_small_branches(self):\n            # no need to use self.base.attr + viewed indices,\n            # since no change is made to the attr in question (nodes)\n            comp_lens = self.nodes[\"length\"]\n            branch_lens = comp_lens.groupby(\"global_branch_index\").sum()\n            return np.sum(branch_lens < 10)\n\n        # Change data in view.\n        def change_attr_in_view(self):\n            # changes to attrs have to be made via self.base.attr + viewed indices\n            a = func1(self.base.attr1[self._cells_in_view])\n            b = func2(self.base.attr2[self._edges_in_view])\n            self.base.attr3[self._branches_in_view] = a + b\n    \"\"\"\n\n    def __init__(self):\n        self.ncomp: int = None\n        self.total_nbranches: int = 0\n        self.nbranches_per_cell: List[int] = None\n\n        self.groups = {}\n\n        self.nodes: Optional[pd.DataFrame] = None\n        self._scope = \"local\"  # defaults to local scope\n        self._nodes_in_view: np.ndarray = None\n        self._edges_in_view: np.ndarray = None\n\n        self.edges = pd.DataFrame(\n            columns=[\n                \"global_edge_index\",\n                \"pre_global_comp_index\",\n                \"post_global_comp_index\",\n                \"pre_locs\",\n                \"post_locs\",\n                \"type\",\n                \"type_ind\",\n            ]\n        )\n\n        self._cumsum_nbranches: Optional[np.ndarray] = None\n\n        self.comb_parents: jnp.ndarray = jnp.asarray([-1])\n\n        self.initialized_morph: bool = False\n        self.initialized_syns: bool = False\n\n        # List of all types of `jx.Synapse`s.\n        self.synapses: List = []\n        self.synapse_param_names = []\n        self.synapse_state_names = []\n        self.synapse_names = []\n        self.synapse_current_names: List[str] = []\n\n        # List of types of all `jx.Channel`s.\n        self.channels: List[Channel] = []\n        self.membrane_current_names: List[str] = []\n\n        # For trainable parameters.\n        self.indices_set_by_trainables: List[jnp.ndarray] = []\n        self.trainable_params: List[Dict[str, jnp.ndarray]] = []\n        self.allow_make_trainable: bool = True\n        self.num_trainable_params: int = 0\n\n        # For recordings.\n        self.recordings: pd.DataFrame = pd.DataFrame().from_dict({})\n\n        # For stimuli or clamps.\n        # E.g. `self.externals = {\"v\": zeros(1000,2), \"i\": ones(1000, 2)}`\n        # for 1000 timesteps and two compartments.\n        self.externals: Dict[str, jnp.ndarray] = {}\n        # E.g. `self.external)inds = {\"v\": jnp.asarray([0,1]), \"i\": jnp.asarray([2,3])}`\n        self.external_inds: Dict[str, jnp.ndarray] = {}\n\n        # x, y, z coordinates and radius.\n        self.xyzr: List[np.ndarray] = []\n        self._radius_generating_fns = None  # Defined by `.read_swc()`.\n\n        # For debugging the solver. Will be empty by default and only filled if\n        # `self._init_morph_for_debugging` is run.\n        self.debug_states = {}\n\n        # needs to be set at the end\n        self.base: Module = self\n\n    def __repr__(self):\n        return f\"{type(self).__name__} with {len(self.channels)} different channels. Use `.nodes` for details.\"\n\n    def __str__(self):\n        return f\"jx.{type(self).__name__}\"\n\n    def __dir__(self):\n        base_dir = object.__dir__(self)\n        return sorted(base_dir + self.synapse_names + list(self.group_nodes.keys()))\n\n    def __getattr__(self, key):\n        # Ensure that hidden methods such as `__deepcopy__` still work.\n        if key.startswith(\"__\"):\n            return super().__getattribute__(key)\n\n        # intercepts calls to groups\n        if key in self.base.groups:\n            view = (\n                self.select(self.groups[key])\n                if key in self.groups\n                else self.select(None)\n            )\n            view._set_controlled_by_param(key)\n            return view\n\n        # intercepts calls to channels\n        if key in [c._name for c in self.base.channels]:\n            channel_names = [c._name for c in self.channels]\n            inds = self.nodes.index[self.nodes[key]].to_numpy()\n            view = self.select(inds) if key in channel_names else self.select(None)\n            view._set_controlled_by_param(key)\n            return view\n\n        # intercepts calls to synapse types\n        if key in self.base.synapse_names:\n            syn_inds = self.edges[self.edges[\"type\"] == key][\n                \"global_edge_index\"\n            ].to_numpy()\n            orig_scope = self._scope\n            view = (\n                self.scope(\"global\").edge(syn_inds).scope(orig_scope)\n                if key in self.synapse_names\n                else self.select(None)\n            )\n            view._set_controlled_by_param(key)  # overwrites param set by edge\n            # Ensure synapse param sharing works with `edge`\n            # `edge` will be removed as part of #463\n            view.edges[\"local_edge_index\"] = np.arange(len(view.edges))\n            return view\n\n    def _childviews(self) -> List[str]:\n        \"\"\"Returns levels that module can be viewed at.\n\n        I.e. for net -> [cell, branch, comp]. For branch -> [comp]\"\"\"\n        levels = [\"network\", \"cell\", \"branch\", \"comp\"]\n        if self._current_view in levels:\n            children = levels[levels.index(self._current_view) + 1 :]\n            return children\n        return []\n\n    def _has_childview(self, key: str) -> bool:\n        child_views = self._childviews()\n        return key in child_views\n\n    def __getitem__(self, index):\n        \"\"\"Lazy indexing of the module.\"\"\"\n        supported_parents = [\"network\", \"cell\", \"branch\"]  # cannot index into comp\n\n        not_group_view = self._current_view not in self.groups\n        assert (\n            self._current_view in supported_parents or not_group_view\n        ), \"Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof.\"\n        index = index if isinstance(index, tuple) else (index,)\n\n        child_views = self._childviews()\n        assert len(index) <= len(child_views), \"Too many indices.\"\n        view = self\n        for i, child in zip(index, child_views):\n            view = view._at_nodes(child, i)\n        return view\n\n    def _update_local_indices(self) -> pd.DataFrame:\n        \"\"\"Compute local indices from the global indices that are in view.\n        This is recomputed everytime a View is created.\"\"\"\n        rerank = lambda df: df.rank(method=\"dense\").astype(int) - 1\n\n        def reorder_cols(\n            df: pd.DataFrame, cols: List[str], first: bool = True\n        ) -> pd.DataFrame:\n            \"\"\"Move cols to front/back.\n\n            Args:\n                df: DataFrame to reorder.\n                cols: List of columns to place before/after remaining columns.\n                first: If True, cols are placed in front, otherwise at the end.\n\n            Returns:\n                DataFrame with reordered columns.\"\"\"\n            new_cols = [col for col in df.columns if first == (col in cols)]\n            new_cols += [col for col in df.columns if first != (col in cols)]\n            return df[new_cols]\n\n        def reindex_a_by_b(\n            df: pd.DataFrame, a: str, b: Optional[Union[str, List[str]]] = None\n        ) -> pd.DataFrame:\n            \"\"\"Reindex based on a different col or several columns\n            for b=[0,0,1,1,2,2,2] -> a=[0,1,0,1,0,1,2]\"\"\"\n            grouped_df = df.groupby(b) if b is not None else df\n            df.loc[:, a] = rerank(grouped_df[a])\n            return df\n\n        index_names = [\"cell_index\", \"branch_index\", \"comp_index\"]  # order is important\n        global_idx_cols = [f\"global_{name}\" for name in index_names]\n        local_idx_cols = [f\"local_{name}\" for name in index_names]\n        idcs = self.nodes[global_idx_cols]\n\n        # update local indices of nodes\n        idcs = reindex_a_by_b(idcs, global_idx_cols[0])\n        idcs = reindex_a_by_b(idcs, global_idx_cols[1], global_idx_cols[0])\n        idcs = reindex_a_by_b(idcs, global_idx_cols[2], global_idx_cols[:2])\n        idcs.columns = [col.replace(\"global\", \"local\") for col in global_idx_cols]\n        self.nodes[local_idx_cols] = idcs[local_idx_cols].astype(int)\n\n        # move indices to the front of the dataframe; move controlled_by_param to the end\n        # move indices of current scope to the front and the others to the back\n        not_scope = \"global\" if self._scope == \"local\" else \"local\"\n        self.nodes = reorder_cols(\n            self.nodes, [f\"{self._scope}_{name}\" for name in index_names], first=True\n        )\n        self.nodes = reorder_cols(\n            self.nodes, [f\"{not_scope}_{name}\" for name in index_names], first=False\n        )\n\n        self.edges = reorder_cols(self.edges, [\"global_edge_index\"])\n        self.nodes = reorder_cols(self.nodes, [\"controlled_by_param\"], first=False)\n        self.edges = reorder_cols(self.edges, [\"controlled_by_param\"], first=False)\n\n    def _init_view(self):\n        \"\"\"Init attributes critical for View.\n\n        Needs to be called at init of a Module.\"\"\"\n        parent = self.__class__.__name__.lower()\n        self._current_view = \"comp\" if parent == \"compartment\" else parent\n        self._nodes_in_view = self.nodes.index.to_numpy()\n        self._edges_in_view = self.edges.index.to_numpy()\n        self.nodes[\"controlled_by_param\"] = 0\n\n    def _compute_coords_of_comp_centers(self) -> np.ndarray:\n        \"\"\"Compute xyz coordinates of compartment centers.\n\n        Centers are the midpoint between the comparment endpoints on the morphology\n        as defined by xyzr.\n\n        Note: For sake of performance, interpolation is not done for each branch\n        individually, but only once along a concatenated (and padded) array of all branches.\n        This means for ncomps = [2,4] and normalized cum_branch_lens of [[0,1],[0,1]] we would\n        interpolate xyz at the locations comp_ends = [[0,0.5,1], [0,0.25,0.5,0.75,1]],\n        where 0 is the start of the branch and 1 is the end point at the full branch_len.\n        To avoid do this in one go we set comp_ends = [0,0.5,1,2,2.25,2.5,2.75,3], and\n        norm_cum_branch_len = [0,1,2,3] incrememting and also padding them by 1 to\n        avoid overlapping branch_lens i.e. norm_cum_branch_len = [0,1,1,2] for only\n        incrementing.\n        \"\"\"\n        nodes_by_branches = self.nodes.groupby(\"global_branch_index\")\n        ncomps = nodes_by_branches[\"global_comp_index\"].nunique().to_numpy()\n\n        comp_ends = [\n            np.linspace(0, 1, ncomp + 1) + 2 * i for i, ncomp in enumerate(ncomps)\n        ]\n        comp_ends = np.hstack(comp_ends)\n\n        comp_ends = comp_ends.reshape(-1)\n        cum_branch_lens = []\n        for i, xyzr in enumerate(self.xyzr):\n            branch_len = np.sqrt(np.sum(np.diff(xyzr[:, :3], axis=0) ** 2, axis=1))\n            cum_branch_len = np.cumsum(np.concatenate([np.array([0]), branch_len]))\n            max_len = cum_branch_len.max()\n            # add padding like above\n            cum_branch_len = cum_branch_len / (max_len if max_len > 0 else 1) + 2 * i\n            cum_branch_len[np.isnan(cum_branch_len)] = 0\n            cum_branch_lens.append(cum_branch_len)\n        cum_branch_lens = np.hstack(cum_branch_lens)\n        xyz = np.vstack(self.xyzr)[:, :3]\n        xyz = v_interp(comp_ends, cum_branch_lens, xyz).T\n        centers = (xyz[:-1] + xyz[1:]) / 2  # unaware of inter vs intra comp centers\n        cum_ncomps = np.cumsum(ncomps)\n        # this means centers between comps have to be removed here\n        between_comp_inds = (cum_ncomps + np.arange(len(cum_ncomps)))[:-1]\n        centers = np.delete(centers, between_comp_inds, axis=0)\n        return centers\n\n    def compute_compartment_centers(self):\n        \"\"\"Add compartment centers to nodes dataframe\"\"\"\n        centers = self._compute_coords_of_comp_centers()\n        self.base.nodes.loc[self._nodes_in_view, [\"x\", \"y\", \"z\"]] = centers\n\n    def _reformat_index(self, idx: Any, dtype: type = int) -> np.ndarray:\n        \"\"\"Transforms different types of indices into an array.\n\n        Takes slice, list, array, ints, range and None and transforms\n        it into array of indices. If index == \"all\" it returns \"all\"\n        to be handled downstream.\n\n        Args:\n            idx: index that specifies at which locations to view the module.\n            dtype: defaults to int, but can also reformat float for use in `loc`\n\n        Returns:\n            array of indices of shape (N,)\"\"\"\n        if is_str_all(idx):  # also asserts that the only allowed str == \"all\"\n            return idx\n\n        np_dtype = np.int64 if dtype is int else np.float64\n        idx = np.array([], dtype=dtype) if idx is None else idx\n        idx = np.array([idx]) if isinstance(idx, (dtype, np_dtype)) else idx\n        idx = np.array(idx) if isinstance(idx, (list, range, pd.Index)) else idx\n\n        idx = np.arange(len(self.base.nodes))[idx] if isinstance(idx, slice) else idx\n        if idx.dtype == bool:\n            shape = (*self.shape, len(self.edges))\n            which_idx = len(idx) == np.array(shape)\n            assert np.any(which_idx), \"Index not matching num of cells/branches/comps.\"\n            dim = shape[np.where(which_idx)[0][0]]\n            idx = np.arange(dim)[idx]\n        assert isinstance(idx, np.ndarray), \"Invalid type\"\n        assert idx.dtype in [np_dtype, bool], \"Invalid dtype\"\n        return idx.reshape(-1)\n\n    def _set_controlled_by_param(self, key: str):\n        \"\"\"Determines which parameters are shared in `make_trainable`.\n\n        Adds column to nodes/edges dataframes to read of shared params from.\n\n        Args:\n            key: key specifying group / view that is in control of the params.\"\"\"\n        if key in [\"comp\", \"branch\", \"cell\"]:\n            self.nodes[\"controlled_by_param\"] = self.nodes[f\"global_{key}_index\"]\n            self.edges[\"controlled_by_param\"] = 0\n        elif key == \"edge\":\n            self.edges[\"controlled_by_param\"] = np.arange(len(self.edges))\n        elif key == \"filter\":\n            self.nodes[\"controlled_by_param\"] = np.arange(len(self.nodes))\n            self.edges[\"controlled_by_param\"] = np.arange(len(self.edges))\n        else:\n            self.nodes[\"controlled_by_param\"] = 0\n            self.edges[\"controlled_by_param\"] = 0\n        self._current_view = key\n\n    def select(\n        self, nodes: np.ndarray = None, edges: np.ndarray = None, sorted: bool = False\n    ) -> View:\n        \"\"\"Return View of the module filtered by specific node or edges indices.\n\n        Args:\n            nodes: indices of nodes to view. If None, all nodes are viewed.\n            edges: indices of edges to view. If None, all edges are viewed.\n            sorted: if True, nodes and edges are sorted.\n\n        Returns:\n            View for subset of selected nodes and/or edges.\"\"\"\n\n        nodes = self._reformat_index(nodes) if nodes is not None else None\n        nodes = self._nodes_in_view if is_str_all(nodes) else nodes\n        nodes = np.sort(nodes) if sorted else nodes\n\n        edges = self._reformat_index(edges) if edges is not None else None\n        edges = self._edges_in_view if is_str_all(edges) else edges\n        edges = np.sort(edges) if sorted else edges\n\n        view = View(self, nodes, edges)\n        view._set_controlled_by_param(\"filter\")\n        return view\n\n    def set_scope(self, scope: str):\n        \"\"\"Toggle between \"global\" or \"local\" scope.\n\n        Determines if global or local indices are used for viewing the module.\n\n        Args:\n            scope: either \"global\" or \"local\".\"\"\"\n        assert scope in [\"global\", \"local\"], \"Invalid scope.\"\n        self._scope = scope\n\n    def scope(self, scope: str) -> View:\n        \"\"\"Return a View of the module with the specified scope.\n\n        For example `cell.scope(\"global\").branch(2).scope(\"local\").comp(1)`\n        will return the 1st compartment of branch 2.\n\n        Args:\n            scope: either \"global\" or \"local\".\n\n        Returns:\n            View with the specified scope.\"\"\"\n        view = self.view\n        view.set_scope(scope)\n        return view\n\n    def _at_nodes(self, key: str, idx: Any) -> View:\n        \"\"\"Return a View of the module filtering `nodes` by specified key and index.\n\n        Keys can be `cell`, `branch`, `comp` and determine which index is used to filter.\n        \"\"\"\n        base_name = self.base.__class__.__name__\n        assert self.base._has_childview(key), f\"{base_name} does not support {key}.\"\n        idx = self._reformat_index(idx)\n        idx = self.nodes[self._scope + f\"_{key}_index\"] if is_str_all(idx) else idx\n        where = self.nodes[self._scope + f\"_{key}_index\"].isin(idx)\n        inds = self.nodes.index[where].to_numpy()\n\n        view = View(self, nodes=inds)\n        view._set_controlled_by_param(key)\n        return view\n\n    def _at_edges(self, key: str, idx: Any) -> View:\n        \"\"\"Return a View of the module filtering `edges` by specified key and index.\n\n        Keys can be `pre`, `post`, `edge` and determine which index is used to filter.\n        \"\"\"\n        idx = self._reformat_index(idx)\n        idx = self.edges[self._scope + f\"_{key}_index\"] if is_str_all(idx) else idx\n        where = self.edges[self._scope + f\"_{key}_index\"].isin(idx)\n        inds = self.edges.index[where].to_numpy()\n\n        view = View(self, edges=inds)\n        view._set_controlled_by_param(key)\n        return view\n\n    def cell(self, idx: Any) -> View:\n        \"\"\"Return a View of the module at the selected cell(s).\n\n        Args:\n            idx: index of the cell to view.\n\n        Returns:\n            View of the module at the specified cell index.\"\"\"\n        return self._at_nodes(\"cell\", idx)\n\n    def branch(self, idx: Any) -> View:\n        \"\"\"Return a View of the module at the selected branches(s).\n\n        Args:\n            idx: index of the branch to view.\n\n        Returns:\n            View of the module at the specified branch index.\"\"\"\n        return self._at_nodes(\"branch\", idx)\n\n    def comp(self, idx: Any) -> View:\n        \"\"\"Return a View of the module at the selected compartments(s).\n\n        Args:\n            idx: index of the comp to view.\n\n        Returns:\n            View of the module at the specified compartment index.\"\"\"\n        return self._at_nodes(\"comp\", idx)\n\n    def edge(self, idx: Any) -> View:\n        \"\"\"Return a View of the module at the selected synapse edges(s).\n\n        Args:\n            idx: index of the edge to view.\n\n        Returns:\n            View of the module at the specified edge index.\"\"\"\n        return self._at_edges(\"edge\", idx)\n\n    def loc(self, at: Any) -> View:\n        \"\"\"Return a View of the module at the selected branch location(s).\n\n        Args:\n            at: location along the branch.\n\n        Returns:\n            View of the module at the specified branch location.\"\"\"\n        global_comp_idxs = []\n        for i in self._branches_in_view:\n            ncomp = self.base.ncomp_per_branch[i]\n            comp_locs = np.linspace(0, 1, ncomp)\n            at = comp_locs if is_str_all(at) else self._reformat_index(at, dtype=float)\n            comp_edges = np.linspace(0, 1 + 1e-10, ncomp + 1)\n            idx = np.digitize(at, comp_edges) - 1 + self.base.cumsum_ncomp[i]\n            global_comp_idxs.append(idx)\n        global_comp_idxs = np.concatenate(global_comp_idxs)\n        orig_scope = self._scope\n        # global scope needed to select correct comps, for i.e. branches w. ncomp=[1,2]\n        # loc(0.9)  will correspond to different local branches (0 vs 1).\n        view = self.scope(\"global\").comp(global_comp_idxs).scope(orig_scope)\n        view._current_view = \"loc\"\n        return view\n\n    @property\n    def _comps_in_view(self):\n        \"\"\"Lists the global compartment indices which are currently part of the view.\"\"\"\n        # method also exists in View. this copy forgoes need to instantiate a View\n        return self.nodes[\"global_comp_index\"].unique()\n\n    @property\n    def _branches_in_view(self):\n        \"\"\"Lists the global branch indices which are currently part of the view.\"\"\"\n        # method also exists in View. this copy forgoes need to instantiate a View\n        return self.nodes[\"global_branch_index\"].unique()\n\n    @property\n    def _cells_in_view(self):\n        \"\"\"Lists the global cell indices which are currently part of the view.\"\"\"\n        # method also exists in View. this copy forgoes need to instantiate a View\n        return self.nodes[\"global_cell_index\"].unique()\n\n    def _iter_submodules(self, name: str):\n        \"\"\"Iterate over submoduleslevel.\n\n        Used for `cells`, `branches`, `comps`.\"\"\"\n        col = self._scope + f\"_{name}_index\"\n        idxs = self.nodes[col].unique()\n        for idx in idxs:\n            yield self._at_nodes(name, idx)\n\n    @property\n    def cells(self):\n        \"\"\"Iterate over all cells in the module.\n\n        Returns a generator that yields a View of each cell.\"\"\"\n        yield from self._iter_submodules(\"cell\")\n\n    @property\n    def branches(self):\n        \"\"\"Iterate over all branches in the module.\n\n        Returns a generator that yields a View of each branch.\"\"\"\n        yield from self._iter_submodules(\"branch\")\n\n    @property\n    def comps(self):\n        \"\"\"Iterate over all compartments in the module.\n        Can be called on any module, i.e. `net.comps`, `cell.comps` or\n        `branch.comps`. `__iter__` does not allow for this.\n\n        Returns a generator that yields a View of each compartment.\"\"\"\n        yield from self._iter_submodules(\"comp\")\n\n    def __iter__(self):\n        \"\"\"Iterate over parts of the module.\n\n        Internally calls `cells`, `branches`, `comps` at the appropriate level.\n\n        Example:\n\n        .. code-block:: python\n\n            for cell in network:\n                for branch in cell:\n                    for comp in branch:\n                        print(comp.nodes.shape)\n        \"\"\"\n        next_level = self._childviews()[0]\n        yield from self._iter_submodules(next_level)\n\n    @property\n    def shape(self) -> Tuple[int]:\n        \"\"\"Returns the number of submodules contained in a module.\n\n        .. code-block:: python\n\n            network.shape = (num_cells, num_branches, num_compartments)\n            cell.shape = (num_branches, num_compartments)\n            branch.shape = (num_compartments,)\n        \"\"\"\n        cols = [\"global_cell_index\", \"global_branch_index\", \"global_comp_index\"]\n        raw_shape = self.nodes[cols].nunique().to_list()\n\n        # ensure (net.shape -> dim=3, cell.shape -> dim=2, branch.shape -> dim=1, comp.shape -> dim=0)\n        levels = [\"network\", \"cell\", \"branch\", \"comp\"]\n        module = self.base.__class__.__name__.lower()\n        module = \"comp\" if module == \"compartment\" else module\n        shape = tuple(raw_shape[levels.index(module) :])\n        return shape\n\n    def copy(\n        self, reset_index: bool = False, as_module: bool = False\n    ) -> Union[Module, View]:\n        \"\"\"Extract part of a module and return a copy of its View or a new module.\n\n        This can be used to call `jx.integrate` on part of a Module.\n\n        Args:\n            reset_index: if True, the indices of the new module are reset to start from 0.\n            as_module: if True, a new module is returned instead of a View.\n\n        Returns:\n            A part of the module or a copied view of it.\"\"\"\n        view = deepcopy(self)\n        warnings.warn(\"This method is experimental, use at your own risk.\")\n        # TODO FROM #447: add reset_index, i.e. for parents, nodes, edges etc. such that they\n        # start from 0/-1 and are contiguous\n        if as_module:\n            raise NotImplementedError(\"Not yet implemented.\")\n            # initialize a new module with the same attributes\n        return view\n\n    @property\n    def view(self):\n        \"\"\"Return view of the module.\"\"\"\n        return View(self, self._nodes_in_view, self._edges_in_view)\n\n    @property\n    def _module_type(self):\n        \"\"\"Return type of the module (compartment, branch, cell, network) as string.\n\n        This is used to perform asserts for some modules (e.g. network cannot use\n        `set_ncomp`) without having to import the module in `base.py`.\"\"\"\n        return self.__class__.__name__.lower()\n\n    def _append_params_and_states(self, param_dict: Dict, state_dict: Dict):\n        \"\"\"Insert the default params of the module (e.g. radius, length).\n\n        This is run at `__init__()`. It does not deal with channels.\n        \"\"\"\n        for param_name, param_value in param_dict.items():\n            self.base.nodes[param_name] = param_value\n        for state_name, state_value in state_dict.items():\n            self.base.nodes[state_name] = state_value\n\n    def _gather_channels_from_constituents(self, constituents: List):\n        \"\"\"Modify `self.channels` and `self.nodes` with channel info from constituents.\n\n        This is run at `__init__()`. It takes all branches of constituents (e.g.\n        of all branches when the are assembled into a cell) and adds columns to\n        `.nodes` for the relevant channels.\n        \"\"\"\n        for module in constituents:\n            for channel in module.channels:\n                if channel._name not in [c._name for c in self.channels]:\n                    self.base.channels.append(channel)\n                if channel.current_name not in self.membrane_current_names:\n                    self.base.membrane_current_names.append(channel.current_name)\n        # Setting columns of channel names to `False` instead of `NaN`.\n        for channel in self.base.channels:\n            name = channel._name\n            self.base.nodes.loc[self.nodes[name].isna(), name] = False\n\n    @only_allow_module\n    def to_jax(self):\n        # TODO FROM #447: Make this work for View?\n        \"\"\"Move `.nodes` to `.jaxnodes`.\n\n        Before the actual simulation is run (via `jx.integrate`), all parameters of\n        the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for\n        simulation, these parameters have to be moved to be `jnp.ndarrays` such that\n        they can be processed on GPU/TPU and such that the simulation can be\n        differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`.\n        \"\"\"\n        self.base.jaxnodes = {}\n        for key, value in self.base.nodes.to_dict(orient=\"list\").items():\n            inds = jnp.arange(len(value))\n            self.base.jaxnodes[key] = jnp.asarray(value)[inds]\n\n        # `jaxedges` contains only parameters (no indices).\n        # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n        # we allow parameter sharing.\n        self.base.jaxedges = {}\n        edges = self.base.edges.to_dict(orient=\"list\")\n        for i, synapse in enumerate(self.base.synapses):\n            condition = np.asarray(edges[\"type_ind\"]) == i\n            for key in synapse.synapse_params:\n                self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n            for key in synapse.synapse_states:\n                self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n\n    def show(\n        self,\n        param_names: Optional[Union[str, List[str]]] = None,\n        *,\n        indices: bool = True,\n        params: bool = True,\n        states: bool = True,\n        channel_names: Optional[List[str]] = None,\n    ) -> pd.DataFrame:\n        \"\"\"Print detailed information about the Module or a view of it.\n\n        Args:\n            param_names: The names of the parameters to show. If `None`, all parameters\n                are shown.\n            indices: Whether to show the indices of the compartments.\n            params: Whether to show the parameters of the compartments.\n            states: Whether to show the states of the compartments.\n            channel_names: The names of the channels to show. If `None`, all channels are\n                shown.\n\n        Returns:\n            A `pd.DataFrame` with the requested information.\n        \"\"\"\n        nodes = self.nodes.copy()  # prevents this from being edited\n\n        cols = []\n        inds = [\"comp_index\", \"branch_index\", \"cell_index\"]\n        scopes = [\"local\", \"global\"]\n        inds = [f\"{s}_{i}\" for i in inds for s in scopes] if indices else []\n        cols += inds\n        cols += [ch._name for ch in self.channels] if channel_names else []\n        cols += (\n            sum([list(ch.channel_params) for ch in self.channels], []) if params else []\n        )\n        cols += (\n            sum([list(ch.channel_states) for ch in self.channels], []) if states else []\n        )\n\n        if not param_names is None:\n            cols = (\n                inds + [c for c in cols if c in param_names]\n                if params\n                else list(param_names)\n            )\n\n        return nodes[cols]\n\n    @only_allow_module\n    def _init_morph(self):\n        \"\"\"Initialize the morphology such that it can be processed by the solvers.\"\"\"\n        self._init_morph_jaxley_spsolve()\n        self._init_morph_jax_spsolve()\n        self.initialized_morph = True\n\n    @abstractmethod\n    def _init_morph_jax_spsolve(self):\n        \"\"\"Initialize the morphology for the JAX sparse solver.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _init_morph_jaxley_spsolve(self):\n        \"\"\"Initialize the morphology for the custom Jaxley solver.\"\"\"\n        raise NotImplementedError\n\n    def _compute_axial_conductances(self, params: Dict[str, jnp.ndarray]):\n        \"\"\"Given radius, length, r_a, compute the axial coupling conductances.\"\"\"\n        return compute_axial_conductances(self._comp_edges, params)\n\n    def set(self, key: str, val: Union[float, jnp.ndarray]):\n        \"\"\"Set parameter of module (or its view) to a new value.\n\n        Note that this function can not be called within `jax.jit` or `jax.grad`.\n        Instead, it should be used set the parameters of the module **before** the\n        simulation. Use `.data_set()` to set parameters during `jax.jit` or\n        `jax.grad`.\n\n        Args:\n            key: The name of the parameter to set.\n            val: The value to set the parameter to. If it is `jnp.ndarray` then it\n                must be of shape `(len(num_compartments))`.\n        \"\"\"\n        if key in self.nodes.columns:\n            not_nan = ~self.nodes[key].isna().to_numpy()\n            self.base.nodes.loc[self._nodes_in_view[not_nan], key] = val\n        elif key in self.edges.columns:\n            not_nan = ~self.edges[key].isna().to_numpy()\n            self.base.edges.loc[self._edges_in_view[not_nan], key] = val\n        else:\n            raise KeyError(f\"Key '{key}' not found in nodes or edges\")\n\n    def data_set(\n        self,\n        key: str,\n        val: Union[float, jnp.ndarray],\n        param_state: Optional[List[Dict]],\n    ):\n        \"\"\"Set parameter of module (or its view) to a new value within `jit`.\n\n        Args:\n            key: The name of the parameter to set.\n            val: The value to set the parameter to. If it is `jnp.ndarray` then it\n                must be of shape `(len(num_compartments))`.\n            param_state: State of the setted parameters, internally used such that this\n                function does not modify global state.\n        \"\"\"\n        # Note: `data_set` does not support arrays for `val`.\n        is_node_param = key in self.nodes.columns\n        data = self.nodes if is_node_param else self.edges\n        viewed_inds = self._nodes_in_view if is_node_param else self._edges_in_view\n        if key in data.columns:\n            not_nan = ~data[key].isna()\n            added_param_state = [\n                {\n                    \"indices\": np.atleast_2d(viewed_inds[not_nan]),\n                    \"key\": key,\n                    \"val\": jnp.atleast_1d(jnp.asarray(val)),\n                }\n            ]\n            if param_state is not None:\n                param_state += added_param_state\n            else:\n                param_state = added_param_state\n        else:\n            raise KeyError(\"Key not recognized.\")\n        return param_state\n\n    def set_ncomp(\n        self,\n        ncomp: int,\n        min_radius: Optional[float] = None,\n    ):\n        \"\"\"Set the number of compartments with which the branch is discretized.\n\n        Args:\n            ncomp: The number of compartments that the branch should be discretized\n                into.\n            min_radius: Only used if the morphology was read from an SWC file. If passed\n                the radius is capped to be at least this value.\n\n        Raises:\n            - When there are stimuli in any compartment in the module.\n            - When there are recordings in any compartment in the module.\n            - When the channels of the compartments are not the same within the branch\n            that is modified.\n            - When the lengths of the compartments are not the same within the branch\n            that is modified.\n            - Unless the morphology was read from an SWC file, when the radiuses of the\n            compartments are not the same within the branch that is modified.\n        \"\"\"\n        assert len(self.base.externals) == 0, \"No stimuli allowed!\"\n        assert len(self.base.recordings) == 0, \"No recordings allowed!\"\n        assert len(self.base.trainable_params) == 0, \"No trainables allowed!\"\n\n        assert self.base._module_type != \"network\", \"This is not allowed for networks.\"\n        assert not (\n            self.base._module_type == \"cell\"\n            and len(self._branches_in_view) == len(self.base._branches_in_view)\n        ), \"This is not allowed for cells.\"\n\n        # Update all attributes that are affected by compartment structure.\n        view = self.nodes.copy()\n        all_nodes = self.base.nodes\n        start_idx = self.nodes[\"global_comp_index\"].to_numpy()[0]\n        ncomp_per_branch = self.base.ncomp_per_branch\n        channel_names = [c._name for c in self.base.channels]\n        channel_param_names = list(\n            chain(*[c.channel_params for c in self.base.channels])\n        )\n        channel_state_names = list(\n            chain(*[c.channel_states for c in self.base.channels])\n        )\n        radius_generating_fns = self.base._radius_generating_fns\n\n        within_branch_radiuses = view[\"radius\"].to_numpy()\n        compartment_lengths = view[\"length\"].to_numpy()\n        num_previous_ncomp = len(within_branch_radiuses)\n        branch_indices = pd.unique(view[\"global_branch_index\"])\n\n        error_msg = lambda name: (\n            f\"You previously modified the {name} of individual compartments, but \"\n            f\"now you are modifying the number of compartments in this branch. \"\n            f\"This is not allowed. First build the morphology with `set_ncomp()` and \"\n            f\"then modify the radiuses and lengths of compartments.\"\n        )\n\n        if (\n            ~np.all(within_branch_radiuses == within_branch_radiuses[0])\n            and radius_generating_fns is None\n        ):\n            raise ValueError(error_msg(\"radius\"))\n\n        for property_name in [\"length\", \"capacitance\", \"axial_resistivity\"]:\n            compartment_properties = view[property_name].to_numpy()\n            if ~np.all(compartment_properties == compartment_properties[0]):\n                raise ValueError(error_msg(property_name))\n\n        if not (self.nodes[channel_names].var() == 0.0).all():\n            raise ValueError(\n                \"Some channel exists only in some compartments of the branch which you\"\n                \"are trying to modify. This is not allowed. First specify the number\"\n                \"of compartments with `.set_ncomp()` and then insert the channels\"\n                \"accordingly.\"\n            )\n\n        if not (\n            self.nodes[channel_param_names + channel_state_names].var() == 0.0\n        ).all():\n            raise ValueError(\n                \"Some channel has different parameters or states between the \"\n                \"different compartments of the branch which you are trying to modify. \"\n                \"This is not allowed. First specify the number of compartments with \"\n                \"`.set_ncomp()` and then insert the channels accordingly.\"\n            )\n\n        # Add new rows as the average of all rows. Special case for the length is below.\n        average_row = self.nodes.mean(skipna=False)\n        average_row = average_row.to_frame().T\n        view = pd.concat([*[average_row] * ncomp], axis=\"rows\")\n\n        # Set the correct datatype after having performed an average which cast\n        # everything to float.\n        integer_cols = [\"global_cell_index\", \"global_branch_index\", \"global_comp_index\"]\n        view[integer_cols] = view[integer_cols].astype(int)\n\n        # Whether or not a channel exists in a compartment is a boolean.\n        boolean_cols = channel_names\n        view[boolean_cols] = view[boolean_cols].astype(bool)\n\n        # Special treatment for the lengths and radiuses. These are not being set as\n        # the average because we:\n        # 1) Want to maintain the total length of a branch.\n        # 2) Want to use the SWC inferred radius.\n        #\n        # Compute new compartment lengths.\n        comp_lengths = np.sum(compartment_lengths) / ncomp\n        view[\"length\"] = comp_lengths\n\n        # Compute new compartment radiuses.\n        if radius_generating_fns is not None:\n            view[\"radius\"] = build_radiuses_from_xyzr(\n                radius_fns=radius_generating_fns,\n                branch_indices=branch_indices,\n                min_radius=min_radius,\n                ncomp=ncomp,\n            )\n        else:\n            view[\"radius\"] = within_branch_radiuses[0] * np.ones(ncomp)\n\n        # Update `.nodes`.\n        # 1) Delete N rows starting from start_idx\n        number_deleted = num_previous_ncomp\n        all_nodes = all_nodes.drop(index=range(start_idx, start_idx + number_deleted))\n\n        # 2) Insert M new rows at the same location\n        df1 = all_nodes.iloc[:start_idx]  # Rows before the insertion point\n        df2 = all_nodes.iloc[start_idx:]  # Rows after the insertion point\n\n        # 3) Combine the parts: before, new rows, and after\n        all_nodes = pd.concat([df1, view, df2]).reset_index(drop=True)\n\n        # Override `comp_index` to just be a consecutive list.\n        all_nodes[\"global_comp_index\"] = np.arange(len(all_nodes))\n\n        # Update compartment structure arguments.\n        ncomp_per_branch[branch_indices] = ncomp\n        ncomp = int(np.max(ncomp_per_branch))\n        cumsum_ncomp = cumsum_leading_zero(ncomp_per_branch)\n        internal_node_inds = np.arange(cumsum_ncomp[-1])\n\n        self.base.nodes = all_nodes\n        self.base.ncomp_per_branch = ncomp_per_branch\n        self.base.ncomp = ncomp\n        self.base.cumsum_ncomp = cumsum_ncomp\n        self.base._internal_node_inds = internal_node_inds\n\n        # Update the morphology indexing (e.g., `.comp_edges`).\n        self.base._initialize()\n        self.base._init_view()\n        self.base._update_local_indices()\n\n    def make_trainable(\n        self,\n        key: str,\n        init_val: Optional[Union[float, list]] = None,\n        verbose: bool = True,\n    ):\n        \"\"\"Make a parameter trainable.\n\n        If a parameter is made trainable, it will be returned by `get_parameters()`\n        and should then be passed to `jx.integrate(..., params=params)`.\n\n        Args:\n            key: Name of the parameter to make trainable.\n            init_val: Initial value of the parameter. If `float`, the same value is\n                used for every created parameter. If `list`, the length of the list has\n                to match the number of created parameters. If `None`, the current\n                parameter value is used and if parameter sharing is performed that the\n                current parameter value is averaged over all shared parameters.\n            verbose: Whether to print the number of parameters that are added and the\n                total number of parameters.\n        \"\"\"\n        assert (\n            self.allow_make_trainable\n        ), \"network.cell('all').make_trainable() is not supported. Use a for-loop over cells.\"\n        ncomps_per_branch = (\n            self.base.nodes[\"global_branch_index\"].value_counts().to_numpy()\n        )\n        assert np.all(\n            ncomps_per_branch == ncomps_per_branch[0]\n        ), \"Parameter sharing is not allowed for modules containing branches with different numbers of compartments.\"\n\n        data = self.nodes if key in self.nodes.columns else None\n        data = self.edges if key in self.edges.columns else data\n\n        assert data is not None, f\"Key '{key}' not found in nodes or edges\"\n        not_nan = ~data[key].isna()\n        data = data.loc[not_nan]\n        assert (\n            len(data) > 0\n        ), \"No settable parameters found in the selected compartments.\"\n\n        grouped_view = data.groupby(\"controlled_by_param\")\n        # Because of this `x.index.values` we cannot support `make_trainable()` on\n        # the module level for synapse parameters (but only for `SynapseView`).\n        inds_of_comps = list(\n            grouped_view.apply(lambda x: x.index.values, include_groups=False)\n        )\n        indices_per_param = jnp.stack(inds_of_comps)\n        # Sorted inds are only used to infer the correct starting values.\n        param_vals = jnp.asarray(\n            [data.loc[inds, key].to_numpy() for inds in inds_of_comps]\n        )\n\n        # Set the value which the trainable parameter should take.\n        num_created_parameters = len(indices_per_param)\n        if init_val is not None:\n            if isinstance(init_val, float):\n                new_params = jnp.asarray([init_val] * num_created_parameters)\n            elif isinstance(init_val, list):\n                assert (\n                    len(init_val) == num_created_parameters\n                ), f\"len(init_val)={len(init_val)}, but trying to create {num_created_parameters} parameters.\"\n                new_params = jnp.asarray(init_val)\n            else:\n                raise ValueError(\n                    f\"init_val must a float, list, or None, but it is a {type(init_val).__name__}.\"\n                )\n        else:\n            new_params = jnp.mean(param_vals, axis=1)\n        self.base.trainable_params.append({key: new_params})\n        self.base.indices_set_by_trainables.append(indices_per_param)\n        self.base.num_trainable_params += num_created_parameters\n        if verbose:\n            print(\n                f\"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}\"\n            )\n\n    def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):\n        \"\"\"Write the trainables into `.nodes` and `.edges`.\n\n        This allows to, e.g., visualize trained networks with `.vis()`.\n\n        Args:\n            trainable_params: The trainable parameters returned by `get_parameters()`.\n        \"\"\"\n        # We do not support views. Why? `jaxedges` does not have any NaN\n        # elements, whereas edges does. Because of this, we already need special\n        # treatment to make this function work, and it would be an even bigger hassle\n        # if we wanted to support this.\n        assert self.__class__.__name__ in [\n            \"Compartment\",\n            \"Branch\",\n            \"Cell\",\n            \"Network\",\n        ], \"Only supports modules.\"\n\n        # We could also implement this without casting the module to jax.\n        # However, I think it allows us to reuse as much code as possible and it avoids\n        # any kind of issues with indexing or parameter sharing (as this is fully\n        # taken care of by `get_all_parameters()`).\n        self.base.to_jax()\n        pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables)\n        all_params = self.base.get_all_parameters(pstate, voltage_solver=\"jaxley.stone\")\n\n        # The value for `delta_t` does not matter here because it is only used to\n        # compute the initial current. However, the initial current cannot be made\n        # trainable and so its value never gets used below.\n        all_states = self.base.get_all_states(pstate, all_params, delta_t=0.025)\n\n        # Loop only over the keys in `pstate` to avoid unnecessary computation.\n        for parameter in pstate:\n            key = parameter[\"key\"]\n            if key in self.base.nodes.columns:\n                vals_to_set = all_params if key in all_params.keys() else all_states\n                self.base.nodes[key] = vals_to_set[key]\n\n        # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n        # we allow parameter sharing.\n        edges = self.base.edges.to_dict(orient=\"list\")\n        for i, synapse in enumerate(self.base.synapses):\n            condition = np.asarray(edges[\"type_ind\"]) == i\n            for key in list(synapse.synapse_params.keys()):\n                self.base.edges.loc[condition, key] = all_params[key]\n            for key in list(synapse.synapse_states.keys()):\n                self.base.edges.loc[condition, key] = all_states[key]\n\n    def distance(self, endpoint: \"View\") -> float:\n        \"\"\"Return the direct distance between two compartments.\n        This does not compute the pathwise distance (which is currently not\n        implemented).\n        Args:\n            endpoint: The compartment to which to compute the distance to.\n        \"\"\"\n        assert len(self.xyzr) == 1 and len(endpoint.xyzr) == 1\n        start_xyz = np.mean(self.xyzr[0][:, :3], axis=0)\n        end_xyz = np.mean(endpoint.xyzr[0][:, :3], axis=0)\n        return np.sqrt(np.sum((start_xyz - end_xyz) ** 2))\n\n    def delete_trainables(self):\n        \"\"\"Removes all trainable parameters from the module.\"\"\"\n\n        if isinstance(self, View):\n            trainables_and_inds = self._filter_trainables(is_viewed=False)\n            self.base.indices_set_by_trainables = trainables_and_inds[0]\n            self.base.trainable_params = trainables_and_inds[1]\n            self.base.num_trainable_params -= self.num_trainable_params\n        else:\n            self.base.indices_set_by_trainables = []\n            self.base.trainable_params = []\n            self.base.num_trainable_params = 0\n        self._update_view()\n\n    def add_to_group(self, group_name: str):\n        \"\"\"Add a view of the module to a group.\n\n        Groups can then be indexed. For example:\n\n        .. code-block:: python\n\n            net.cell(0).add_to_group(\"excitatory\")\n            net.excitatory.set(\"radius\", 0.1)\n\n        Args:\n            group_name: The name of the group.\n        \"\"\"\n        if group_name not in self.base.groups:\n            self.base.groups[group_name] = self._nodes_in_view\n        else:\n            self.base.groups[group_name] = np.unique(\n                np.concatenate([self.base.groups[group_name], self._nodes_in_view])\n            )\n\n    def _get_state_names(self) -> Tuple[List, List]:\n        \"\"\"Collect all recordable / clampable states in the membrane and synapses.\n\n        Returns states seperated by comps and edges.\"\"\"\n        channel_states = [name for c in self.channels for name in c.channel_states]\n        synapse_states = [\n            name for s in self.synapses if s is not None for name in s.synapse_states\n        ]\n        membrane_states = [\"v\", \"i\"] + self.membrane_current_names\n        return (\n            channel_states + membrane_states,\n            synapse_states + self.synapse_current_names,\n        )\n\n    def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:\n        \"\"\"Get all trainable parameters.\n\n        The returned parameters should be passed to `jx.integrate(..., params=params).\n\n        Returns:\n            A list of all trainable parameters in the form of\n                [{\"gNa\": jnp.array([0.1, 0.2, 0.3])}, ...].\n        \"\"\"\n        return self.trainable_params\n\n    @only_allow_module\n    def get_all_parameters(\n        self, pstate: List[Dict], voltage_solver: str\n    ) -> Dict[str, jnp.ndarray]:\n        # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n        \"\"\"Return all parameters (and coupling conductances) needed to simulate.\n\n        Runs `_compute_axial_conductances()` and return every parameter that is needed\n        to solve the ODE. This includes conductances, radiuses, lengths,\n        axial_resistivities, but also coupling conductances.\n\n        This is done by first obtaining the current value of every parameter (not only\n        the trainable ones) and then replacing the trainable ones with the value\n        in `trainable_params()`. This function is run within `jx.integrate()`.\n\n        pstate can be obtained by calling `params_to_pstate()`.\n\n        .. code-block:: python\n\n            params = module.get_parameters() # i.e. [0, 1, 2]\n            pstate = params_to_pstate(params, module.indices_set_by_trainables)\n            module.to_jax() # needed for call to module.jaxnodes\n\n        Args:\n            pstate: The state of the trainable parameters. pstate takes the form\n                [{\n                    \"key\": \"gNa\", \"indices\": jnp.array([0, 1, 2]),\n                    \"val\": jnp.array([0.1, 0.2, 0.3])\n                }, ...].\n            voltage_solver: The voltage solver that is used. Since `jax.sparse` and\n                `jaxley.xyz` require different formats of the axial conductances, this\n                function will default to different building methods.\n\n        Returns:\n            A dictionary of all module parameters.\n        \"\"\"\n        params = {}\n        for key in [\"radius\", \"length\", \"axial_resistivity\", \"capacitance\"]:\n            params[key] = self.base.jaxnodes[key]\n\n        for channel in self.base.channels:\n            for channel_params in channel.channel_params:\n                params[channel_params] = self.base.jaxnodes[channel_params]\n\n        for synapse_params in self.base.synapse_param_names:\n            params[synapse_params] = self.base.jaxedges[synapse_params]\n\n        # Override with those parameters set by `.make_trainable()`.\n        for parameter in pstate:\n            key = parameter[\"key\"]\n            inds = parameter[\"indices\"]\n            set_param = parameter[\"val\"]\n\n            # This is needed since SynapseViews worked differently before.\n            # This mimics the old behaviour and tranformes the new indices\n            # to the old indices.\n            # TODO FROM #447: Longterm this should be gotten rid of.\n            # Instead edges should work similar to nodes (would also allow for\n            # param sharing).\n            synapse_inds = self.base.edges.groupby(\"type\").rank()[\"global_edge_index\"]\n            synapse_inds = (synapse_inds.astype(int) - 1).to_numpy()\n            if key in self.base.synapse_param_names:\n                inds = synapse_inds[inds]\n\n            if key in params:  # Only parameters, not initial states.\n                # `inds` is of shape `(num_params, num_comps_per_param)`.\n                # `set_param` is of shape `(num_params,)`\n                # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n                # `.set()` to work. This is done with `[:, None]`.\n                params[key] = params[key].at[inds].set(set_param[:, None])\n\n        # Compute conductance params and add them to the params dictionary.\n        params[\"axial_conductances\"] = self.base._compute_axial_conductances(\n            params=params\n        )\n        return params\n\n    @only_allow_module\n    def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]:\n        # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n        \"\"\"Return states as they are set in the `.nodes` and `.edges` tables.\"\"\"\n        self.base.to_jax()  # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.\n        states = {\"v\": self.base.jaxnodes[\"v\"]}\n        # Join node and edge states into a single state dictionary.\n        for channel in self.base.channels:\n            for channel_states in channel.channel_states:\n                states[channel_states] = self.base.jaxnodes[channel_states]\n        for synapse_states in self.base.synapse_state_names:\n            states[synapse_states] = self.base.jaxedges[synapse_states]\n        return states\n\n    @only_allow_module\n    def get_all_states(\n        self, pstate: List[Dict], all_params, delta_t: float\n    ) -> Dict[str, jnp.ndarray]:\n        # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n        \"\"\"Get the full initial state of the module from jaxnodes and trainables.\n\n        Args:\n            pstate: The state of the trainable parameters.\n            all_params: All parameters of the module.\n            delta_t: The time step.\n\n        Returns:\n            A dictionary of all states of the module.\n        \"\"\"\n        states = self.base._get_states_from_nodes_and_edges()\n\n        # Override with the initial states set by `.make_trainable()`.\n        for parameter in pstate:\n            key = parameter[\"key\"]\n            inds = parameter[\"indices\"]\n            set_param = parameter[\"val\"]\n            if key in states:  # Only initial states, not parameters.\n                # `inds` is of shape `(num_params, num_comps_per_param)`.\n                # `set_param` is of shape `(num_params,)`\n                # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n                # `.set()` to work. This is done with `[:, None]`.\n                states[key] = states[key].at[inds].set(set_param[:, None])\n\n        # Add to the states the initial current through every channel.\n        states, _ = self.base._channel_currents(\n            states, delta_t, self.channels, self.nodes, all_params\n        )\n\n        # Add to the states the initial current through every synapse.\n        states, _ = self.base._synapse_currents(\n            states, self.synapses, all_params, delta_t, self.edges\n        )\n        return states\n\n    @property\n    def initialized(self) -> bool:\n        \"\"\"Whether the `Module` is ready to be solved or not.\"\"\"\n        return self.initialized_morph\n\n    def _initialize(self):\n        \"\"\"Initialize the module.\"\"\"\n        self._init_morph()\n        return self\n\n    @only_allow_module\n    def init_states(self, delta_t: float = 0.025):\n        # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n        \"\"\"Initialize all mechanisms in their steady state.\n\n        This considers the voltages and parameters of each compartment.\n\n        Args:\n            delta_t: Passed on to `channel.init_state()`.\n        \"\"\"\n        # Update states of the channels.\n        channel_nodes = self.base.nodes\n        states = self.base._get_states_from_nodes_and_edges()\n\n        # We do not use any `pstate` for initializing. In principle, we could change\n        # that by allowing an input `params` and `pstate` to this function.\n        # `voltage_solver` could also be `jax.sparse` here, because both of them\n        # build the channel parameters in the same way.\n        params = self.base.get_all_parameters([], voltage_solver=\"jaxley.thomas\")\n\n        for channel in self.base.channels:\n            name = channel._name\n            channel_indices = channel_nodes.loc[channel_nodes[name]][\n                \"global_comp_index\"\n            ].to_numpy()\n            voltages = channel_nodes.loc[channel_indices, \"v\"].to_numpy()\n\n            channel_param_names = list(channel.channel_params.keys())\n            channel_state_names = list(channel.channel_states.keys())\n            channel_states = query_channel_states_and_params(\n                states, channel_state_names, channel_indices\n            )\n            channel_params = query_channel_states_and_params(\n                params, channel_param_names, channel_indices\n            )\n\n            init_state = channel.init_state(\n                channel_states, voltages, channel_params, delta_t\n            )\n\n            # `init_state` might not return all channel states. Only the ones that are\n            # returned are updated here.\n            for key, val in init_state.items():\n                # Note that we are overriding `self.nodes` here, but `self.nodes` is\n                # not used above to actually compute the current states (so there are\n                # no issues with overriding states).\n                self.nodes.loc[channel_indices, key] = val\n\n    def _init_morph_for_debugging(self):\n        \"\"\"Instandiates row and column inds which can be used to solve the voltage eqs.\n\n        This is important only for expert users who try to modify the solver for the\n        voltage equations. By default, this function is never run.\n\n        This is useful for debugging the solver because one can use\n        `scipy.linalg.sparse.spsolve` after every step of the solve.\n\n        Here is the code snippet that can be used for debugging then (to be inserted in\n        `solver_voltage`):\n        ```python\n        from scipy.sparse import csc_matrix\n        from scipy.sparse.linalg import spsolve\n        from jaxley.utils.debug_solver import build_voltage_matrix_elements\n\n        elements, solve, num_entries, start_ind_for_branchpoints = (\n            build_voltage_matrix_elements(\n                uppers,\n                lowers,\n                diags,\n                solves,\n                branchpoint_conds_children[debug_states[\"child_inds\"]],\n                branchpoint_conds_parents[debug_states[\"par_inds\"]],\n                branchpoint_weights_children[debug_states[\"child_inds\"]],\n                branchpoint_weights_parents[debug_states[\"par_inds\"]],\n                branchpoint_diags,\n                branchpoint_solves,\n                debug_states[\"ncomp\"],\n                nbranches,\n            )\n        )\n        sparse_matrix = csc_matrix(\n            (elements, (debug_states[\"row_inds\"], debug_states[\"col_inds\"])),\n            shape=(num_entries, num_entries),\n        )\n        solution = spsolve(sparse_matrix, solve)\n        solution = solution[:start_ind_for_branchpoints]  # Delete branchpoint voltages.\n        solves = jnp.reshape(solution, (debug_states[\"ncomp\"], nbranches))\n        return solves\n        ```\n        \"\"\"\n        # For scipy and jax.scipy.\n        row_and_col_inds = compute_morphology_indices(\n            len(self.base._par_inds),\n            self.base._child_belongs_to_branchpoint,\n            self.base._par_inds,\n            self.base._child_inds,\n            self.base.ncomp,\n            self.base.total_nbranches,\n        )\n\n        num_elements = len(row_and_col_inds[\"row_inds\"])\n        data_inds, indices, indptr = convert_to_csc(\n            num_elements=num_elements,\n            row_ind=row_and_col_inds[\"row_inds\"],\n            col_ind=row_and_col_inds[\"col_inds\"],\n        )\n        self.base.debug_states[\"row_inds\"] = row_and_col_inds[\"row_inds\"]\n        self.base.debug_states[\"col_inds\"] = row_and_col_inds[\"col_inds\"]\n        self.base.debug_states[\"data_inds\"] = data_inds\n        self.base.debug_states[\"indices\"] = indices\n        self.base.debug_states[\"indptr\"] = indptr\n\n        self.base.debug_states[\"ncomp\"] = self.base.ncomp\n        self.base.debug_states[\"child_inds\"] = self.base._child_inds\n        self.base.debug_states[\"par_inds\"] = self.base._par_inds\n\n    def record(self, state: str = \"v\", verbose=True):\n        comp_states, edge_states = self._get_state_names()\n        if state not in comp_states + edge_states:\n            raise KeyError(f\"{state} is not a recognized state in this module.\")\n        in_view = self._nodes_in_view if state in comp_states else self._edges_in_view\n\n        new_recs = pd.DataFrame(in_view, columns=[\"rec_index\"])\n        new_recs[\"state\"] = state\n        self.base.recordings = pd.concat([self.base.recordings, new_recs])\n        has_duplicates = self.base.recordings.duplicated()\n        self.base.recordings = self.base.recordings.loc[~has_duplicates]\n        if verbose:\n            print(\n                f\"Added {len(in_view)-sum(has_duplicates)} recordings. See `.recordings` for details.\"\n            )\n\n    def _update_view(self):\n        \"\"\"Update the attrs of the view after changes in the base module.\"\"\"\n        if isinstance(self, View):\n            scope = self._scope\n            current_view = self._current_view\n            # copy dict of new View. For some reason doing self = View(self)\n            # did not work.\n            self.__dict__ = View(\n                self.base, self._nodes_in_view, self._edges_in_view\n            ).__dict__\n\n            # retain the scope and current_view of the previous view\n            self._scope = scope\n            self._current_view = current_view\n\n    def delete_recordings(self):\n        \"\"\"Removes all recordings from the module.\"\"\"\n        if isinstance(self, View):\n            base_recs = self.base.recordings\n            self.base.recordings = base_recs[\n                ~base_recs.isin(self.recordings).all(axis=1)\n            ]\n            self._update_view()\n        else:\n            self.base.recordings = pd.DataFrame().from_dict({})\n\n    def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True):\n        \"\"\"Insert a stimulus into the compartment.\n\n        current must be a 1d array or have batch dimension of size `(num_compartments, )`\n        or `(1, )`. If 1d, the same stimulus is added to all compartments.\n\n        This function cannot be run during `jax.jit` and `jax.grad`. Because of this,\n        it should only be used for static stimuli (i.e., stimuli that do not depend\n        on the data and that should not be learned). For stimuli that depend on data\n        (or that should be learned), please use `data_stimulate()`.\n\n        Args:\n            current: Current in `nA`.\n        \"\"\"\n        self._external_input(\"i\", current, verbose=verbose)\n\n    def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True):\n        \"\"\"Clamp a state to a given value across specified compartments.\n\n        Args:\n            state_name: The name of the state to clamp.\n            state_array (jnp.nd: Array of values to clamp the state to.\n            verbose : If True, prints details about the clamping.\n\n        This function sets external states for the compartments.\n        \"\"\"\n        self._external_input(state_name, state_array, verbose=verbose)\n\n    def _external_input(\n        self,\n        key: str,\n        values: Optional[jnp.ndarray],\n        verbose: bool = True,\n    ):\n        comp_states, edge_states = self._get_state_names()\n        if key not in comp_states + edge_states:\n            raise KeyError(f\"{key} is not a recognized state in this module.\")\n        values = values if values.ndim == 2 else jnp.expand_dims(values, axis=0)\n        batch_size = values.shape[0]\n        num_inserted = (\n            len(self._nodes_in_view) if key in comp_states else len(self._edges_in_view)\n        )\n        is_multiple = num_inserted == batch_size\n        values = values if is_multiple else jnp.repeat(values, num_inserted, axis=0)\n        assert batch_size in [\n            1,\n            num_inserted,\n        ], \"Number of comps and stimuli do not match.\"\n\n        if key in self.base.externals.keys():\n            self.base.externals[key] = jnp.concatenate(\n                [self.base.externals[key], values]\n            )\n            self.base.external_inds[key] = jnp.concatenate(\n                [self.base.external_inds[key], self._nodes_in_view]\n            )\n        else:\n            if key in comp_states:\n                self.base.externals[key] = values\n                self.base.external_inds[key] = self._nodes_in_view\n            else:\n                self.base.externals[key] = values\n                self.base.external_inds[key] = self._edges_in_view\n        if verbose:\n            print(\n                f\"Added {num_inserted} external_states. See `.externals` for details.\"\n            )\n\n    def data_stimulate(\n        self,\n        current: jnp.ndarray,\n        data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n        verbose: bool = False,\n    ) -> Tuple[jnp.ndarray, pd.DataFrame]:\n        \"\"\"Insert a stimulus into the module within jit (or grad).\n\n        Args:\n            current: Current in `nA`.\n            verbose: Whether or not to print the number of inserted stimuli. `False`\n                by default because this method is meant to be jitted.\n        \"\"\"\n        return self._data_external_input(\n            \"i\", current, data_stimuli, self.nodes, verbose=verbose\n        )\n\n    def data_clamp(\n        self,\n        state_name: str,\n        state_array: jnp.ndarray,\n        data_clamps: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n        verbose: bool = False,\n    ):\n        \"\"\"Insert a clamp into the module within jit (or grad).\n\n        Args:\n            state_name: Name of the state variable to set.\n            state_array: Time series of the state variable in the default Jaxley unit.\n                State array should be of shape (num_clamps, simulation_time) or\n                (simulation_time, ) for a single clamp.\n            verbose: Whether or not to print the number of inserted clamps. `False`\n                by default because this method is meant to be jitted.\n        \"\"\"\n        comp_states, edge_states = self._get_state_names()\n        if state_name not in comp_states + edge_states:\n            raise KeyError(f\"{state_name} is not a recognized state in this module.\")\n        data = self.nodes if state_name in comp_states else self.edges\n        return self._data_external_input(\n            state_name, state_array, data_clamps, data, verbose=verbose\n        )\n\n    def _data_external_input(\n        self,\n        state_name: str,\n        state_array: jnp.ndarray,\n        data_external_input: Optional[Tuple[jnp.ndarray, pd.DataFrame]],\n        view: pd.DataFrame,\n        verbose: bool = False,\n    ):\n        comp_states, edge_states = self._get_state_names()\n        state_array = (\n            state_array\n            if state_array.ndim == 2\n            else jnp.expand_dims(state_array, axis=0)\n        )\n        batch_size = state_array.shape[0]\n        num_inserted = (\n            len(self._nodes_in_view)\n            if state_name in comp_states\n            else len(self._edges_in_view)\n        )\n        is_multiple = num_inserted == batch_size\n        state_array = (\n            state_array\n            if is_multiple\n            else jnp.repeat(state_array, num_inserted, axis=0)\n        )\n        assert batch_size in [\n            1,\n            num_inserted,\n        ], \"Number of comps and clamps do not match.\"\n\n        if data_external_input is not None:\n            external_input = data_external_input[1]\n            external_input = jnp.concatenate([external_input, state_array])\n            inds = data_external_input[2]\n        else:\n            external_input = state_array\n            inds = pd.DataFrame().from_dict({})\n\n        inds = pd.concat([inds, view])\n\n        if verbose:\n            if state_name == \"i\":\n                print(f\"Added {len(view)} stimuli.\")\n            else:\n                print(f\"Added {len(view)} clamps.\")\n\n        return (state_name, external_input, inds)\n\n    def delete_stimuli(self):\n        \"\"\"Removes all stimuli from the module.\"\"\"\n        self.delete_clamps(\"i\")\n\n    def delete_clamps(self, state_name: Optional[str] = None):\n        \"\"\"Removes all clamps of the given state from the module.\"\"\"\n        all_externals = list(self.externals.keys())\n        if \"i\" in all_externals:\n            all_externals.remove(\"i\")\n        state_names = all_externals if state_name is None else [state_name]\n        for state_name in state_names:\n            if state_name in self.externals:\n                keep_inds = ~np.isin(\n                    self.base.external_inds[state_name], self._nodes_in_view\n                )\n                base_exts = self.base.externals\n                base_exts_inds = self.base.external_inds\n                if np.all(~keep_inds):\n                    base_exts.pop(state_name, None)\n                    base_exts_inds.pop(state_name, None)\n                else:\n                    base_exts[state_name] = base_exts[state_name][keep_inds]\n                    base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds]\n                self._update_view()\n            else:\n                pass  # does not have to be deleted if not in externals\n\n    def insert(self, channel: Channel):\n        \"\"\"Insert a channel into the module.\n\n        Args:\n            channel: The channel to insert.\"\"\"\n        name = channel._name\n\n        # Channel does not yet exist in the `jx.Module` at all.\n        if name not in [c._name for c in self.base.channels]:\n            self.base.channels.append(channel)\n            self.base.nodes[name] = (\n                False  # Previous columns do not have the new channel.\n            )\n\n        if channel.current_name not in self.base.membrane_current_names:\n            self.base.membrane_current_names.append(channel.current_name)\n\n        # Add a binary column that indicates if a channel is present.\n        self.base.nodes.loc[self._nodes_in_view, name] = True\n\n        # Loop over all new parameters, e.g. gNa, eNa.\n        for key in channel.channel_params:\n            self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_params[key]\n\n        # Loop over all new parameters, e.g. gNa, eNa.\n        for key in channel.channel_states:\n            self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key]\n\n    def delete_channel(self, channel: Channel):\n        \"\"\"Remove a channel from the module.\n\n        Args:\n            channel: The channel to remove.\"\"\"\n        name = channel._name\n        channel_names = [c._name for c in self.channels]\n        all_channel_names = [c._name for c in self.base.channels]\n        if name in channel_names:\n            channel_cols = list(channel.channel_params.keys())\n            channel_cols += list(channel.channel_states.keys())\n            self.base.nodes.loc[self._nodes_in_view, channel_cols] = float(\"nan\")\n            self.base.nodes.loc[self._nodes_in_view, name] = False\n\n            # only delete cols if no other comps in the module have the same channel\n            if np.all(~self.base.nodes[name]):\n                self.base.channels.pop(all_channel_names.index(name))\n                self.base.membrane_current_names.remove(channel.current_name)\n                self.base.nodes.drop(columns=channel_cols + [name], inplace=True)\n        else:\n            raise ValueError(f\"Channel {name} not found in the module.\")\n\n    @only_allow_module\n    def step(\n        self,\n        u: Dict[str, jnp.ndarray],\n        delta_t: float,\n        external_inds: Dict[str, jnp.ndarray],\n        externals: Dict[str, jnp.ndarray],\n        params: Dict[str, jnp.ndarray],\n        solver: str = \"bwd_euler\",\n        voltage_solver: str = \"jaxley.stone\",\n    ) -> Dict[str, jnp.ndarray]:\n        \"\"\"One step of solving the Ordinary Differential Equation.\n\n        This function is called inside of `integrate` and increments the state of the\n        module by one time step. Calls `_step_channels` and `_step_synapse` to update\n        the states of the channels and synapses using fwd_euler.\n\n        Args:\n            u: The state of the module. voltages = u[\"v\"]\n            delta_t: The time step.\n            external_inds: The indices of the external inputs.\n            externals: The external inputs.\n            params: The parameters of the module.\n            solver: The solver to use for the voltages. Either of [\"bwd_euler\",\n                \"fwd_euler\", \"crank_nicolson\"].\n            voltage_solver: The tridiagonal solver used to diagonalize the\n                coefficient matrix of the ODE system. Either of [\"jaxley.thomas\",\n                \"jaxley.stone\"].\n\n        Returns:\n            The updated state of the module.\n        \"\"\"\n\n        # Extract the voltages\n        voltages = u[\"v\"]\n\n        # Extract the external inputs\n        if \"i\" in externals.keys():\n            i_current = externals[\"i\"]\n            i_inds = external_inds[\"i\"]\n            i_ext = self._get_external_input(\n                voltages, i_inds, i_current, params[\"radius\"], params[\"length\"]\n            )\n        else:\n            i_ext = 0.0\n\n        # Step of the channels.\n        u, (v_terms, const_terms) = self._step_channels(\n            u, delta_t, self.channels, self.nodes, params\n        )\n\n        # Step of the synapse.\n        u, (syn_v_terms, syn_const_terms) = self._step_synapse(\n            u,\n            self.synapses,\n            params,\n            delta_t,\n            self.edges,\n        )\n\n        # Clamp for channels and synapses.\n        for key in externals.keys():\n            if key not in [\"i\", \"v\"]:\n                u[key] = u[key].at[external_inds[key]].set(externals[key])\n\n        # Voltage steps.\n        cm = params[\"capacitance\"]  # Abbreviation.\n\n        # Arguments used by all solvers.\n        solver_kwargs = {\n            \"voltages\": voltages,\n            \"voltage_terms\": (v_terms + syn_v_terms) / cm,\n            \"constant_terms\": (const_terms + i_ext + syn_const_terms) / cm,\n            \"axial_conductances\": params[\"axial_conductances\"],\n            \"internal_node_inds\": self._internal_node_inds,\n        }\n\n        # Add solver specific arguments.\n        if voltage_solver == \"jax.sparse\":\n            solver_kwargs.update(\n                {\n                    \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n                    \"data_inds\": self._data_inds,\n                    \"indices\": self._indices_jax_spsolve,\n                    \"indptr\": self._indptr_jax_spsolve,\n                    \"n_nodes\": self._n_nodes,\n                }\n            )\n            # Only for `bwd_euler` and `cranck-nicolson`.\n            step_voltage_implicit = step_voltage_implicit_with_jax_spsolve\n        else:\n            # Our custom sparse solver requires a different format of all conductance\n            # values to perform triangulation and backsubstution optimally.\n            #\n            # Currently, the forward Euler solver also uses this format. However,\n            # this is only for historical reasons and we are planning to change this in\n            # the future.\n            solver_kwargs.update(\n                {\n                    \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n                    \"sources\": np.asarray(self._comp_edges[\"source\"].to_list()),\n                    \"types\": np.asarray(self._comp_edges[\"type\"].to_list()),\n                    \"ncomp_per_branch\": self.ncomp_per_branch,\n                    \"par_inds\": self._par_inds,\n                    \"child_inds\": self._child_inds,\n                    \"nbranches\": self.total_nbranches,\n                    \"solver\": voltage_solver,\n                    \"idx\": self._solve_indexer,\n                    \"debug_states\": self.debug_states,\n                }\n            )\n            # Only for `bwd_euler` and `cranck-nicolson`.\n            step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve\n\n        if solver == \"bwd_euler\":\n            u[\"v\"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t)\n        elif solver == \"crank_nicolson\":\n            # Crank-Nicolson advances by half a step of backward and half a step of\n            # forward Euler.\n            half_step_delta_t = delta_t / 2\n            half_step_voltages = step_voltage_implicit(\n                **solver_kwargs, delta_t=half_step_delta_t\n            )\n            # The forward Euler step in Crank-Nicolson can be performed easily as\n            # `V_{n+1} = 2 * V_{n+1/2} - V_n`. See also NEURON book Chapter 4.\n            u[\"v\"] = 2 * half_step_voltages - voltages\n        elif solver == \"fwd_euler\":\n            u[\"v\"] = step_voltage_explicit(**solver_kwargs, delta_t=delta_t)\n        else:\n            raise ValueError(\n                f\"You specified `solver={solver}`. The only allowed solvers are \"\n                \"['bwd_euler', 'fwd_euler', 'crank_nicolson'].\"\n            )\n\n        # Clamp for voltages.\n        if \"v\" in externals.keys():\n            u[\"v\"] = u[\"v\"].at[external_inds[\"v\"]].set(externals[\"v\"])\n\n        return u\n\n    def _step_channels(\n        self,\n        states: Dict[str, jnp.ndarray],\n        delta_t: float,\n        channels: List[Channel],\n        channel_nodes: pd.DataFrame,\n        params: Dict[str, jnp.ndarray],\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"One step of integration of the channels and of computing their current.\"\"\"\n        states = self._step_channels_state(\n            states, delta_t, channels, channel_nodes, params\n        )\n        states, current_terms = self._channel_currents(\n            states, delta_t, channels, channel_nodes, params\n        )\n        return states, current_terms\n\n    def _step_channels_state(\n        self,\n        states,\n        delta_t,\n        channels: List[Channel],\n        channel_nodes: pd.DataFrame,\n        params: Dict[str, jnp.ndarray],\n    ) -> Dict[str, jnp.ndarray]:\n        \"\"\"One integration step of the channels.\"\"\"\n        voltages = states[\"v\"]\n\n        # Update states of the channels.\n        indices = channel_nodes[\"global_comp_index\"].to_numpy()\n        for channel in channels:\n            channel_param_names = list(channel.channel_params)\n            channel_param_names += [\n                \"radius\",\n                \"length\",\n                \"axial_resistivity\",\n                \"capacitance\",\n            ]\n            channel_state_names = list(channel.channel_states)\n            channel_state_names += self.membrane_current_names\n            channel_indices = indices[channel_nodes[channel._name].astype(bool)]\n\n            channel_params = query_channel_states_and_params(\n                params, channel_param_names, channel_indices\n            )\n            channel_states = query_channel_states_and_params(\n                states, channel_state_names, channel_indices\n            )\n\n            states_updated = channel.update_states(\n                channel_states, delta_t, voltages[channel_indices], channel_params\n            )\n            # Rebuild state. This has to be done within the loop over channels to allow\n            # multiple channels which modify the same state.\n            for key, val in states_updated.items():\n                states[key] = states[key].at[channel_indices].set(val)\n\n        return states\n\n    def _channel_currents(\n        self,\n        states: Dict[str, jnp.ndarray],\n        delta_t: float,\n        channels: List[Channel],\n        channel_nodes: pd.DataFrame,\n        params: Dict[str, jnp.ndarray],\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"Return the current through each channel.\n\n        This is also updates `state` because the `state` also contains the current.\n        \"\"\"\n        voltages = states[\"v\"]\n\n        # Compute current through channels.\n        voltage_terms = jnp.zeros_like(voltages)\n        constant_terms = jnp.zeros_like(voltages)\n        # Run with two different voltages that are `diff` apart to infer the slope and\n        # offset.\n        diff = 1e-3\n\n        current_states = {}\n        for name in self.membrane_current_names:\n            current_states[name] = jnp.zeros_like(voltages)\n\n        for channel in channels:\n            name = channel._name\n            channel_param_names = list(channel.channel_params.keys())\n            channel_state_names = list(channel.channel_states.keys())\n            indices = channel_nodes.loc[channel_nodes[name]][\n                \"global_comp_index\"\n            ].to_numpy()\n\n            channel_params = {}\n            for p in channel_param_names:\n                channel_params[p] = params[p][indices]\n            channel_params[\"radius\"] = params[\"radius\"][indices]\n            channel_params[\"length\"] = params[\"length\"][indices]\n            channel_params[\"axial_resistivity\"] = params[\"axial_resistivity\"][indices]\n\n            channel_states = {}\n            for s in channel_state_names:\n                channel_states[s] = states[s][indices]\n\n            v_and_perturbed = jnp.stack([voltages[indices], voltages[indices] + diff])\n            membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))(\n                channel_states, v_and_perturbed, channel_params\n            )\n            voltage_term = (membrane_currents[1] - membrane_currents[0]) / diff\n            constant_term = membrane_currents[0] - voltage_term * voltages[indices]\n\n            # * 1000 to convert from mA/cm^2 to uA/cm^2.\n            voltage_terms = voltage_terms.at[indices].add(voltage_term * 1000.0)\n            constant_terms = constant_terms.at[indices].add(-constant_term * 1000.0)\n\n            # Save the current (for the unperturbed voltage) as a state that will\n            # also be passed to the state update.\n            current_states[channel.current_name] = (\n                current_states[channel.current_name]\n                .at[indices]\n                .add(membrane_currents[0])\n            )\n\n        # Copy the currents into the `state` dictionary such that they can be\n        # recorded and used by `Channel.update_states()`.\n        for name in self.membrane_current_names:\n            states[name] = current_states[name]\n\n        return states, (voltage_terms, constant_terms)\n\n    def _step_synapse(\n        self,\n        u: Dict[str, jnp.ndarray],\n        syn_channels: List[Channel],\n        params: Dict[str, jnp.ndarray],\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"One step of integration of the channels.\n\n        `Network` overrides this method (because it actually has synapses), whereas\n        `Compartment`, `Branch`, and `Cell` do not override this.\n        \"\"\"\n        voltages = u[\"v\"]\n        return u, (jnp.zeros_like(voltages), jnp.zeros_like(voltages))\n\n    def _synapse_currents(\n        self, states, syn_channels, params, delta_t, edges: pd.DataFrame\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        return states, (None, None)\n\n    @staticmethod\n    def _get_external_input(\n        voltages: jnp.ndarray,\n        i_inds: jnp.ndarray,\n        i_stim: jnp.ndarray,\n        radius: float,\n        length_single_compartment: float,\n    ) -> jnp.ndarray:\n        \"\"\"\n        Return external input to each compartment in uA / cm^2.\n\n        Args:\n            voltages: mV.\n            i_stim: nA.\n            radius: um.\n            length_single_compartment: um.\n        \"\"\"\n        zero_vec = jnp.zeros_like(voltages)\n        current = convert_point_process_to_distributed(\n            i_stim, radius[i_inds], length_single_compartment[i_inds]\n        )\n\n        dnums = ScatterDimensionNumbers(\n            update_window_dims=(),\n            inserted_window_dims=(0,),\n            scatter_dims_to_operand_dims=(0,),\n        )\n        stim_at_timestep = scatter_add(zero_vec, i_inds[:, None], current, dnums)\n        return stim_at_timestep\n\n    def vis(\n        self,\n        ax: Optional[Axes] = None,\n        color: str = \"k\",\n        dims: Tuple[int] = (0, 1),\n        type: str = \"line\",\n        **kwargs,\n    ) -> Axes:\n        \"\"\"Visualize the module.\n\n        Modules can be visualized on one of the cardinal planes (xy, xz, yz) or\n        even in 3D.\n\n        Several options are available:\n        - `line`: All points from the traced morphology (`xyzr`), are connected\n        with a line plot.\n        - `scatter`: All traced points, are plotted as scatter points.\n        - `comp`: Plots the compartmentalized morphology, including radius\n        and shape. (shows the true compartment lengths per default, but this can\n        be changed via the `kwargs`, for details see\n        `jaxley.utils.plot_utils.plot_comps`).\n        - `morph`: Reconstructs the 3D shape of the traced morphology. For details see\n        `jaxley.utils.plot_utils.plot_morph`. Warning: For 3D plots and morphologies\n        with many traced points this can be very slow.\n\n        Args:\n            ax: An axis into which to plot.\n            color: The color for all branches.\n            dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n                two of them.\n            type: The type of plot. One of [\"line\", \"scatter\", \"comp\", \"morph\"].\n            kwargs: Keyword arguments passed to the plotting function.\n        \"\"\"\n        res = 100 if \"resolution\" not in kwargs else kwargs.pop(\"resolution\")\n        if \"comp\" in type.lower():\n            return plot_comps(\n                self, dims=dims, ax=ax, color=color, resolution=res, **kwargs\n            )\n        if \"morph\" in type.lower():\n            return plot_morph(\n                self, dims=dims, ax=ax, color=color, resolution=res, **kwargs\n            )\n\n        assert not np.any(\n            [np.isnan(xyzr[:, dims]).all() for xyzr in self.xyzr]\n        ), \"No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`.\"\n\n        ax = plot_graph(\n            self.xyzr,\n            dims=dims,\n            color=color,\n            ax=ax,\n            type=type,\n            **kwargs,\n        )\n\n        return ax\n\n    def compute_xyz(self):\n        \"\"\"Return xyz coordinates of every branch, based on the branch length.\n\n        This function should not be called if the morphology was read from an `.swc`\n        file. However, for morphologies that were constructed from scratch, this\n        function **must** be called before `.vis()`. The computed `xyz` coordinates\n        are only used for plotting.\n        \"\"\"\n        max_y_multiplier = 5.0\n        min_y_multiplier = 0.5\n\n        parents = self.comb_parents\n        num_children = _compute_num_children(parents)\n        index_of_child = _compute_index_of_child(parents)\n        levels = compute_levels(parents)\n\n        # Extract branch.\n        inds_branch = self.nodes.groupby(\"global_branch_index\")[\n            \"global_comp_index\"\n        ].apply(list)\n        branch_lens = [np.sum(self.nodes[\"length\"][np.asarray(i)]) for i in inds_branch]\n        endpoints = []\n\n        # Different levels will get a different \"angle\" at which the children emerge from\n        # the parents. This angle is defined by the `y_offset_multiplier`. This value\n        # defines the range between y-location of the first and of the last child of a\n        # parent.\n        y_offset_multiplier = np.linspace(\n            max_y_multiplier, min_y_multiplier, np.max(levels) + 1\n        )\n\n        for b in range(self.total_nbranches):\n            # For networks with mixed SWC and from-scatch neurons, only update those\n            # branches that do not have coordingates yet.\n            if np.any(np.isnan(self.xyzr[b])):\n                if parents[b] > -1:\n                    start_point = endpoints[parents[b]]\n                    num_children_of_parent = num_children[parents[b]]\n                    if num_children_of_parent > 1:\n                        y_offset = (\n                            ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5\n                        ) * y_offset_multiplier[levels[b]]\n                    else:\n                        y_offset = 0.0\n                else:\n                    start_point = [0, 0, 0]\n                    y_offset = 0.0\n\n                len_of_path = np.sqrt(y_offset**2 + 1.0)\n\n                end_point = [\n                    start_point[0] + branch_lens[b] / len_of_path * 1.0,\n                    start_point[1] + branch_lens[b] / len_of_path * y_offset,\n                    start_point[2],\n                ]\n                endpoints.append(end_point)\n\n                self.xyzr[b][:, :3] = np.asarray([start_point, end_point])\n            else:\n                # Dummy to keey the index `endpoints[parent[b]]` above working.\n                endpoints.append(np.zeros((2,)))\n\n    def move(\n        self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = False\n    ):\n        \"\"\"Move cells or networks by adding to their (x, y, z) coordinates.\n\n        This function is used only for visualization. It does not affect the simulation.\n\n        Args:\n            x: The amount to move in the x direction in um.\n            y: The amount to move in the y direction in um.\n            z: The amount to move in the z direction in um.\n            update_nodes: Whether `.nodes` should be updated or not. Setting this to\n                `False` largely speeds up moving, especially for big networks, but\n                `.nodes` or `.show` will not show the new xyz coordinates.\n        \"\"\"\n        for i in self._branches_in_view:\n            self.base.xyzr[i][:, :3] += np.array([x, y, z])\n        if update_nodes:\n            self.compute_compartment_centers()\n\n    def move_to(\n        self,\n        x: Union[float, np.ndarray] = 0.0,\n        y: Union[float, np.ndarray] = 0.0,\n        z: Union[float, np.ndarray] = 0.0,\n        update_nodes: bool = False,\n    ):\n        \"\"\"Move cells or networks to a location (x, y, z).\n\n        If x, y, and z are floats, then the first compartment of the first branch\n        of the first cell is moved to that float coordinate, and everything else is\n        shifted by the difference between that compartment's previous coordinate and\n        the new float location.\n\n        If x, y, and z are arrays, then they must each have a length equal to the number\n        of cells being moved. Then the first compartment of the first branch of each\n        cell is moved to the specified location.\n\n        Args:\n            update_nodes: Whether `.nodes` should be updated or not. Setting this to\n                `False` largely speeds up moving, especially for big networks, but\n                `.nodes` or `.show` will not show the new xyz coordinates.\n        \"\"\"\n        # Test if any coordinate values are NaN which would greatly affect moving\n        if np.any(np.concatenate(self.xyzr, axis=0)[:, :3] == np.nan):\n            raise ValueError(\n                \"NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values.\"\n            )\n\n        # can only iterate over cells for networks\n        # lambda makes sure that generator can be created multiple times\n        base_is_net = self.base._current_view == \"network\"\n        cells = lambda: (self.cells if base_is_net else [self])\n\n        root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in cells()])\n        root_xyz = root_xyz_cells[0] if isinstance(x, float) else root_xyz_cells\n        move_by = np.array([x, y, z]).T - root_xyz\n\n        if len(move_by.shape) == 1:\n            move_by = np.tile(move_by, (len(self._cells_in_view), 1))\n\n        for cell, offset in zip(cells(), move_by):\n            for idx in cell._branches_in_view:\n                self.base.xyzr[idx][:, :3] += offset\n        if update_nodes:\n            self.compute_compartment_centers()\n\n    def rotate(\n        self, degrees: float, rotation_axis: str = \"xy\", update_nodes: bool = False\n    ):\n        \"\"\"Rotate jaxley modules clockwise. Used only for visualization.\n\n        This function is used only for visualization. It does not affect the simulation.\n\n        Args:\n            degrees: How many degrees to rotate the module by.\n            rotation_axis: Either of {`xy` | `xz` | `yz`}.\n        \"\"\"\n        degrees = degrees / 180 * np.pi\n        if rotation_axis == \"xy\":\n            dims = [0, 1]\n        elif rotation_axis == \"xz\":\n            dims = [0, 2]\n        elif rotation_axis == \"yz\":\n            dims = [1, 2]\n        else:\n            raise ValueError\n\n        rotation_matrix = np.asarray(\n            [[np.cos(degrees), np.sin(degrees)], [-np.sin(degrees), np.cos(degrees)]]\n        )\n        for i in self._branches_in_view:\n            rot = np.dot(rotation_matrix, self.base.xyzr[i][:, dims].T).T\n            self.base.xyzr[i][:, dims] = rot\n        if update_nodes:\n            self.compute_compartment_centers()\n\n    def copy_node_property_to_edges(\n        self,\n        properties_to_import: Union[str, List[str]],\n        pre_or_post: Union[str, List[str]] = [\"pre\", \"post\"],\n    ) -> Module:\n        \"\"\"Copy a property that is in `node` over to `edges`.\n\n        By default, `.edges` does not contain the properties (radius, length, cm,\n        channel properties,...) of the pre- and post-synaptic compartments. This\n        method allows to copy a property of the pre- and/or post-synaptic compartment\n        to the edges. It is then accessible as `module.edges.pre_property_name` or\n        `module.edges.post_property_name`.\n\n        Note that, if you modify the node property _after_ having run\n        `copy_node_property_to_edges`, it will not automatically update the value in\n        `.edges`.\n\n        Note that, if this method is called on a View (e.g.\n        `net.cell(0).copy_node_property_to_edges`), then it will return a View, but\n        it will _not_ modify the module itself.\n\n        Args:\n            properties_to_import: The name of the node properties that should be\n                imported. To list all available properties, look at\n                `module.nodes.columns`.\n            pre_or_post: Whether to import only the pre-synaptic property ('pre'), only\n                the post-synaptic property ('post'), or both (['pre', 'post']).\n\n        Returns:\n            A new module which has the property copied to the `nodes`.\n        \"\"\"\n        # If a string is passed, wrap it as a list.\n        if isinstance(pre_or_post, str):\n            pre_or_post = [pre_or_post]\n        if isinstance(properties_to_import, str):\n            properties_to_import = [properties_to_import]\n\n        for pre_or_post_val in pre_or_post:\n            assert pre_or_post_val in [\"pre\", \"post\"]\n            for property_to_import in properties_to_import:\n                # Delete the column if it already exists. Otherwise it would exist\n                # twice.\n                if f\"{pre_or_post_val}_{property_to_import}\" in self.edges.columns:\n                    self.edges.drop(\n                        columns=f\"{pre_or_post_val}_{property_to_import}\", inplace=True\n                    )\n\n                self.edges = self.edges.join(\n                    self.nodes[[property_to_import, \"global_comp_index\"]].set_index(\n                        \"global_comp_index\"\n                    ),\n                    on=f\"{pre_or_post_val}_global_comp_index\",\n                )\n                self.edges = self.edges.rename(\n                    columns={\n                        property_to_import: f\"{pre_or_post_val}_{property_to_import}\"\n                    }\n                )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.branches","title":"branches property","text":"

Iterate over all branches in the module.

Returns a generator that yields a View of each branch.

"},{"location":"reference/modules/#jaxley.modules.base.Module.cells","title":"cells property","text":"

Iterate over all cells in the module.

Returns a generator that yields a View of each cell.

"},{"location":"reference/modules/#jaxley.modules.base.Module.comps","title":"comps property","text":"

Iterate over all compartments in the module. Can be called on any module, i.e. net.comps, cell.comps or branch.comps. __iter__ does not allow for this.

Returns a generator that yields a View of each compartment.

"},{"location":"reference/modules/#jaxley.modules.base.Module.initialized","title":"initialized: bool property","text":"

Whether the Module is ready to be solved or not.

"},{"location":"reference/modules/#jaxley.modules.base.Module.shape","title":"shape: Tuple[int] property","text":"

Returns the number of submodules contained in a module.

.. code-block:: python

network.shape = (num_cells, num_branches, num_compartments)\ncell.shape = (num_branches, num_compartments)\nbranch.shape = (num_compartments,)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.view","title":"view property","text":"

Return view of the module.

"},{"location":"reference/modules/#jaxley.modules.base.Module.__getitem__","title":"__getitem__(index)","text":"

Lazy indexing of the module.

Source code in jaxley/modules/base.py
def __getitem__(self, index):\n    \"\"\"Lazy indexing of the module.\"\"\"\n    supported_parents = [\"network\", \"cell\", \"branch\"]  # cannot index into comp\n\n    not_group_view = self._current_view not in self.groups\n    assert (\n        self._current_view in supported_parents or not_group_view\n    ), \"Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof.\"\n    index = index if isinstance(index, tuple) else (index,)\n\n    child_views = self._childviews()\n    assert len(index) <= len(child_views), \"Too many indices.\"\n    view = self\n    for i, child in zip(index, child_views):\n        view = view._at_nodes(child, i)\n    return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.__iter__","title":"__iter__()","text":"

Iterate over parts of the module.

Internally calls cells, branches, comps at the appropriate level.

Example:

.. code-block:: python

for cell in network:\n    for branch in cell:\n        for comp in branch:\n            print(comp.nodes.shape)\n
Source code in jaxley/modules/base.py
def __iter__(self):\n    \"\"\"Iterate over parts of the module.\n\n    Internally calls `cells`, `branches`, `comps` at the appropriate level.\n\n    Example:\n\n    .. code-block:: python\n\n        for cell in network:\n            for branch in cell:\n                for comp in branch:\n                    print(comp.nodes.shape)\n    \"\"\"\n    next_level = self._childviews()[0]\n    yield from self._iter_submodules(next_level)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.add_to_group","title":"add_to_group(group_name)","text":"

Add a view of the module to a group.

Groups can then be indexed. For example:

.. code-block:: python

net.cell(0).add_to_group(\"excitatory\")\nnet.excitatory.set(\"radius\", 0.1)\n

Parameters:

Name Type Description Default group_name str

The name of the group.

required Source code in jaxley/modules/base.py
def add_to_group(self, group_name: str):\n    \"\"\"Add a view of the module to a group.\n\n    Groups can then be indexed. For example:\n\n    .. code-block:: python\n\n        net.cell(0).add_to_group(\"excitatory\")\n        net.excitatory.set(\"radius\", 0.1)\n\n    Args:\n        group_name: The name of the group.\n    \"\"\"\n    if group_name not in self.base.groups:\n        self.base.groups[group_name] = self._nodes_in_view\n    else:\n        self.base.groups[group_name] = np.unique(\n            np.concatenate([self.base.groups[group_name], self._nodes_in_view])\n        )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.branch","title":"branch(idx)","text":"

Return a View of the module at the selected branches(s).

Parameters:

Name Type Description Default idx Any

index of the branch to view.

required

Returns:

Type Description View

View of the module at the specified branch index.

Source code in jaxley/modules/base.py
def branch(self, idx: Any) -> View:\n    \"\"\"Return a View of the module at the selected branches(s).\n\n    Args:\n        idx: index of the branch to view.\n\n    Returns:\n        View of the module at the specified branch index.\"\"\"\n    return self._at_nodes(\"branch\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.cell","title":"cell(idx)","text":"

Return a View of the module at the selected cell(s).

Parameters:

Name Type Description Default idx Any

index of the cell to view.

required

Returns:

Type Description View

View of the module at the specified cell index.

Source code in jaxley/modules/base.py
def cell(self, idx: Any) -> View:\n    \"\"\"Return a View of the module at the selected cell(s).\n\n    Args:\n        idx: index of the cell to view.\n\n    Returns:\n        View of the module at the specified cell index.\"\"\"\n    return self._at_nodes(\"cell\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.clamp","title":"clamp(state_name, state_array, verbose=True)","text":"

Clamp a state to a given value across specified compartments.

Parameters:

Name Type Description Default state_name str

The name of the state to clamp.

required state_array nd

Array of values to clamp the state to.

required verbose

If True, prints details about the clamping.

True

This function sets external states for the compartments.

Source code in jaxley/modules/base.py
def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True):\n    \"\"\"Clamp a state to a given value across specified compartments.\n\n    Args:\n        state_name: The name of the state to clamp.\n        state_array (jnp.nd: Array of values to clamp the state to.\n        verbose : If True, prints details about the clamping.\n\n    This function sets external states for the compartments.\n    \"\"\"\n    self._external_input(state_name, state_array, verbose=verbose)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.comp","title":"comp(idx)","text":"

Return a View of the module at the selected compartments(s).

Parameters:

Name Type Description Default idx Any

index of the comp to view.

required

Returns:

Type Description View

View of the module at the specified compartment index.

Source code in jaxley/modules/base.py
def comp(self, idx: Any) -> View:\n    \"\"\"Return a View of the module at the selected compartments(s).\n\n    Args:\n        idx: index of the comp to view.\n\n    Returns:\n        View of the module at the specified compartment index.\"\"\"\n    return self._at_nodes(\"comp\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.compute_compartment_centers","title":"compute_compartment_centers()","text":"

Add compartment centers to nodes dataframe

Source code in jaxley/modules/base.py
def compute_compartment_centers(self):\n    \"\"\"Add compartment centers to nodes dataframe\"\"\"\n    centers = self._compute_coords_of_comp_centers()\n    self.base.nodes.loc[self._nodes_in_view, [\"x\", \"y\", \"z\"]] = centers\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.compute_xyz","title":"compute_xyz()","text":"

Return xyz coordinates of every branch, based on the branch length.

This function should not be called if the morphology was read from an .swc file. However, for morphologies that were constructed from scratch, this function must be called before .vis(). The computed xyz coordinates are only used for plotting.

Source code in jaxley/modules/base.py
def compute_xyz(self):\n    \"\"\"Return xyz coordinates of every branch, based on the branch length.\n\n    This function should not be called if the morphology was read from an `.swc`\n    file. However, for morphologies that were constructed from scratch, this\n    function **must** be called before `.vis()`. The computed `xyz` coordinates\n    are only used for plotting.\n    \"\"\"\n    max_y_multiplier = 5.0\n    min_y_multiplier = 0.5\n\n    parents = self.comb_parents\n    num_children = _compute_num_children(parents)\n    index_of_child = _compute_index_of_child(parents)\n    levels = compute_levels(parents)\n\n    # Extract branch.\n    inds_branch = self.nodes.groupby(\"global_branch_index\")[\n        \"global_comp_index\"\n    ].apply(list)\n    branch_lens = [np.sum(self.nodes[\"length\"][np.asarray(i)]) for i in inds_branch]\n    endpoints = []\n\n    # Different levels will get a different \"angle\" at which the children emerge from\n    # the parents. This angle is defined by the `y_offset_multiplier`. This value\n    # defines the range between y-location of the first and of the last child of a\n    # parent.\n    y_offset_multiplier = np.linspace(\n        max_y_multiplier, min_y_multiplier, np.max(levels) + 1\n    )\n\n    for b in range(self.total_nbranches):\n        # For networks with mixed SWC and from-scatch neurons, only update those\n        # branches that do not have coordingates yet.\n        if np.any(np.isnan(self.xyzr[b])):\n            if parents[b] > -1:\n                start_point = endpoints[parents[b]]\n                num_children_of_parent = num_children[parents[b]]\n                if num_children_of_parent > 1:\n                    y_offset = (\n                        ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5\n                    ) * y_offset_multiplier[levels[b]]\n                else:\n                    y_offset = 0.0\n            else:\n                start_point = [0, 0, 0]\n                y_offset = 0.0\n\n            len_of_path = np.sqrt(y_offset**2 + 1.0)\n\n            end_point = [\n                start_point[0] + branch_lens[b] / len_of_path * 1.0,\n                start_point[1] + branch_lens[b] / len_of_path * y_offset,\n                start_point[2],\n            ]\n            endpoints.append(end_point)\n\n            self.xyzr[b][:, :3] = np.asarray([start_point, end_point])\n        else:\n            # Dummy to keey the index `endpoints[parent[b]]` above working.\n            endpoints.append(np.zeros((2,)))\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.copy","title":"copy(reset_index=False, as_module=False)","text":"

Extract part of a module and return a copy of its View or a new module.

This can be used to call jx.integrate on part of a Module.

Parameters:

Name Type Description Default reset_index bool

if True, the indices of the new module are reset to start from 0.

False as_module bool

if True, a new module is returned instead of a View.

False

Returns:

Type Description Union[Module, View]

A part of the module or a copied view of it.

Source code in jaxley/modules/base.py
def copy(\n    self, reset_index: bool = False, as_module: bool = False\n) -> Union[Module, View]:\n    \"\"\"Extract part of a module and return a copy of its View or a new module.\n\n    This can be used to call `jx.integrate` on part of a Module.\n\n    Args:\n        reset_index: if True, the indices of the new module are reset to start from 0.\n        as_module: if True, a new module is returned instead of a View.\n\n    Returns:\n        A part of the module or a copied view of it.\"\"\"\n    view = deepcopy(self)\n    warnings.warn(\"This method is experimental, use at your own risk.\")\n    # TODO FROM #447: add reset_index, i.e. for parents, nodes, edges etc. such that they\n    # start from 0/-1 and are contiguous\n    if as_module:\n        raise NotImplementedError(\"Not yet implemented.\")\n        # initialize a new module with the same attributes\n    return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.copy_node_property_to_edges","title":"copy_node_property_to_edges(properties_to_import, pre_or_post=['pre', 'post'])","text":"

Copy a property that is in node over to edges.

By default, .edges does not contain the properties (radius, length, cm, channel properties,\u2026) of the pre- and post-synaptic compartments. This method allows to copy a property of the pre- and/or post-synaptic compartment to the edges. It is then accessible as module.edges.pre_property_name or module.edges.post_property_name.

Note that, if you modify the node property after having run copy_node_property_to_edges, it will not automatically update the value in .edges.

Note that, if this method is called on a View (e.g. net.cell(0).copy_node_property_to_edges), then it will return a View, but it will not modify the module itself.

Parameters:

Name Type Description Default properties_to_import Union[str, List[str]]

The name of the node properties that should be imported. To list all available properties, look at module.nodes.columns.

required pre_or_post Union[str, List[str]]

Whether to import only the pre-synaptic property (\u2018pre\u2019), only the post-synaptic property (\u2018post\u2019), or both ([\u2018pre\u2019, \u2018post\u2019]).

['pre', 'post']

Returns:

Type Description Module

A new module which has the property copied to the nodes.

Source code in jaxley/modules/base.py
def copy_node_property_to_edges(\n    self,\n    properties_to_import: Union[str, List[str]],\n    pre_or_post: Union[str, List[str]] = [\"pre\", \"post\"],\n) -> Module:\n    \"\"\"Copy a property that is in `node` over to `edges`.\n\n    By default, `.edges` does not contain the properties (radius, length, cm,\n    channel properties,...) of the pre- and post-synaptic compartments. This\n    method allows to copy a property of the pre- and/or post-synaptic compartment\n    to the edges. It is then accessible as `module.edges.pre_property_name` or\n    `module.edges.post_property_name`.\n\n    Note that, if you modify the node property _after_ having run\n    `copy_node_property_to_edges`, it will not automatically update the value in\n    `.edges`.\n\n    Note that, if this method is called on a View (e.g.\n    `net.cell(0).copy_node_property_to_edges`), then it will return a View, but\n    it will _not_ modify the module itself.\n\n    Args:\n        properties_to_import: The name of the node properties that should be\n            imported. To list all available properties, look at\n            `module.nodes.columns`.\n        pre_or_post: Whether to import only the pre-synaptic property ('pre'), only\n            the post-synaptic property ('post'), or both (['pre', 'post']).\n\n    Returns:\n        A new module which has the property copied to the `nodes`.\n    \"\"\"\n    # If a string is passed, wrap it as a list.\n    if isinstance(pre_or_post, str):\n        pre_or_post = [pre_or_post]\n    if isinstance(properties_to_import, str):\n        properties_to_import = [properties_to_import]\n\n    for pre_or_post_val in pre_or_post:\n        assert pre_or_post_val in [\"pre\", \"post\"]\n        for property_to_import in properties_to_import:\n            # Delete the column if it already exists. Otherwise it would exist\n            # twice.\n            if f\"{pre_or_post_val}_{property_to_import}\" in self.edges.columns:\n                self.edges.drop(\n                    columns=f\"{pre_or_post_val}_{property_to_import}\", inplace=True\n                )\n\n            self.edges = self.edges.join(\n                self.nodes[[property_to_import, \"global_comp_index\"]].set_index(\n                    \"global_comp_index\"\n                ),\n                on=f\"{pre_or_post_val}_global_comp_index\",\n            )\n            self.edges = self.edges.rename(\n                columns={\n                    property_to_import: f\"{pre_or_post_val}_{property_to_import}\"\n                }\n            )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.data_clamp","title":"data_clamp(state_name, state_array, data_clamps=None, verbose=False)","text":"

Insert a clamp into the module within jit (or grad).

Parameters:

Name Type Description Default state_name str

Name of the state variable to set.

required state_array ndarray

Time series of the state variable in the default Jaxley unit. State array should be of shape (num_clamps, simulation_time) or (simulation_time, ) for a single clamp.

required verbose bool

Whether or not to print the number of inserted clamps. False by default because this method is meant to be jitted.

False Source code in jaxley/modules/base.py
def data_clamp(\n    self,\n    state_name: str,\n    state_array: jnp.ndarray,\n    data_clamps: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n    verbose: bool = False,\n):\n    \"\"\"Insert a clamp into the module within jit (or grad).\n\n    Args:\n        state_name: Name of the state variable to set.\n        state_array: Time series of the state variable in the default Jaxley unit.\n            State array should be of shape (num_clamps, simulation_time) or\n            (simulation_time, ) for a single clamp.\n        verbose: Whether or not to print the number of inserted clamps. `False`\n            by default because this method is meant to be jitted.\n    \"\"\"\n    comp_states, edge_states = self._get_state_names()\n    if state_name not in comp_states + edge_states:\n        raise KeyError(f\"{state_name} is not a recognized state in this module.\")\n    data = self.nodes if state_name in comp_states else self.edges\n    return self._data_external_input(\n        state_name, state_array, data_clamps, data, verbose=verbose\n    )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.data_set","title":"data_set(key, val, param_state)","text":"

Set parameter of module (or its view) to a new value within jit.

Parameters:

Name Type Description Default key str

The name of the parameter to set.

required val Union[float, ndarray]

The value to set the parameter to. If it is jnp.ndarray then it must be of shape (len(num_compartments)).

required param_state Optional[List[Dict]]

State of the setted parameters, internally used such that this function does not modify global state.

required Source code in jaxley/modules/base.py
def data_set(\n    self,\n    key: str,\n    val: Union[float, jnp.ndarray],\n    param_state: Optional[List[Dict]],\n):\n    \"\"\"Set parameter of module (or its view) to a new value within `jit`.\n\n    Args:\n        key: The name of the parameter to set.\n        val: The value to set the parameter to. If it is `jnp.ndarray` then it\n            must be of shape `(len(num_compartments))`.\n        param_state: State of the setted parameters, internally used such that this\n            function does not modify global state.\n    \"\"\"\n    # Note: `data_set` does not support arrays for `val`.\n    is_node_param = key in self.nodes.columns\n    data = self.nodes if is_node_param else self.edges\n    viewed_inds = self._nodes_in_view if is_node_param else self._edges_in_view\n    if key in data.columns:\n        not_nan = ~data[key].isna()\n        added_param_state = [\n            {\n                \"indices\": np.atleast_2d(viewed_inds[not_nan]),\n                \"key\": key,\n                \"val\": jnp.atleast_1d(jnp.asarray(val)),\n            }\n        ]\n        if param_state is not None:\n            param_state += added_param_state\n        else:\n            param_state = added_param_state\n    else:\n        raise KeyError(\"Key not recognized.\")\n    return param_state\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.data_stimulate","title":"data_stimulate(current, data_stimuli=None, verbose=False)","text":"

Insert a stimulus into the module within jit (or grad).

Parameters:

Name Type Description Default current ndarray

Current in nA.

required verbose bool

Whether or not to print the number of inserted stimuli. False by default because this method is meant to be jitted.

False Source code in jaxley/modules/base.py
def data_stimulate(\n    self,\n    current: jnp.ndarray,\n    data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n    verbose: bool = False,\n) -> Tuple[jnp.ndarray, pd.DataFrame]:\n    \"\"\"Insert a stimulus into the module within jit (or grad).\n\n    Args:\n        current: Current in `nA`.\n        verbose: Whether or not to print the number of inserted stimuli. `False`\n            by default because this method is meant to be jitted.\n    \"\"\"\n    return self._data_external_input(\n        \"i\", current, data_stimuli, self.nodes, verbose=verbose\n    )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_channel","title":"delete_channel(channel)","text":"

Remove a channel from the module.

Parameters:

Name Type Description Default channel Channel

The channel to remove.

required Source code in jaxley/modules/base.py
def delete_channel(self, channel: Channel):\n    \"\"\"Remove a channel from the module.\n\n    Args:\n        channel: The channel to remove.\"\"\"\n    name = channel._name\n    channel_names = [c._name for c in self.channels]\n    all_channel_names = [c._name for c in self.base.channels]\n    if name in channel_names:\n        channel_cols = list(channel.channel_params.keys())\n        channel_cols += list(channel.channel_states.keys())\n        self.base.nodes.loc[self._nodes_in_view, channel_cols] = float(\"nan\")\n        self.base.nodes.loc[self._nodes_in_view, name] = False\n\n        # only delete cols if no other comps in the module have the same channel\n        if np.all(~self.base.nodes[name]):\n            self.base.channels.pop(all_channel_names.index(name))\n            self.base.membrane_current_names.remove(channel.current_name)\n            self.base.nodes.drop(columns=channel_cols + [name], inplace=True)\n    else:\n        raise ValueError(f\"Channel {name} not found in the module.\")\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_clamps","title":"delete_clamps(state_name=None)","text":"

Removes all clamps of the given state from the module.

Source code in jaxley/modules/base.py
def delete_clamps(self, state_name: Optional[str] = None):\n    \"\"\"Removes all clamps of the given state from the module.\"\"\"\n    all_externals = list(self.externals.keys())\n    if \"i\" in all_externals:\n        all_externals.remove(\"i\")\n    state_names = all_externals if state_name is None else [state_name]\n    for state_name in state_names:\n        if state_name in self.externals:\n            keep_inds = ~np.isin(\n                self.base.external_inds[state_name], self._nodes_in_view\n            )\n            base_exts = self.base.externals\n            base_exts_inds = self.base.external_inds\n            if np.all(~keep_inds):\n                base_exts.pop(state_name, None)\n                base_exts_inds.pop(state_name, None)\n            else:\n                base_exts[state_name] = base_exts[state_name][keep_inds]\n                base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds]\n            self._update_view()\n        else:\n            pass  # does not have to be deleted if not in externals\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_recordings","title":"delete_recordings()","text":"

Removes all recordings from the module.

Source code in jaxley/modules/base.py
def delete_recordings(self):\n    \"\"\"Removes all recordings from the module.\"\"\"\n    if isinstance(self, View):\n        base_recs = self.base.recordings\n        self.base.recordings = base_recs[\n            ~base_recs.isin(self.recordings).all(axis=1)\n        ]\n        self._update_view()\n    else:\n        self.base.recordings = pd.DataFrame().from_dict({})\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_stimuli","title":"delete_stimuli()","text":"

Removes all stimuli from the module.

Source code in jaxley/modules/base.py
def delete_stimuli(self):\n    \"\"\"Removes all stimuli from the module.\"\"\"\n    self.delete_clamps(\"i\")\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_trainables","title":"delete_trainables()","text":"

Removes all trainable parameters from the module.

Source code in jaxley/modules/base.py
def delete_trainables(self):\n    \"\"\"Removes all trainable parameters from the module.\"\"\"\n\n    if isinstance(self, View):\n        trainables_and_inds = self._filter_trainables(is_viewed=False)\n        self.base.indices_set_by_trainables = trainables_and_inds[0]\n        self.base.trainable_params = trainables_and_inds[1]\n        self.base.num_trainable_params -= self.num_trainable_params\n    else:\n        self.base.indices_set_by_trainables = []\n        self.base.trainable_params = []\n        self.base.num_trainable_params = 0\n    self._update_view()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.distance","title":"distance(endpoint)","text":"

Return the direct distance between two compartments. This does not compute the pathwise distance (which is currently not implemented). Args: endpoint: The compartment to which to compute the distance to.

Source code in jaxley/modules/base.py
def distance(self, endpoint: \"View\") -> float:\n    \"\"\"Return the direct distance between two compartments.\n    This does not compute the pathwise distance (which is currently not\n    implemented).\n    Args:\n        endpoint: The compartment to which to compute the distance to.\n    \"\"\"\n    assert len(self.xyzr) == 1 and len(endpoint.xyzr) == 1\n    start_xyz = np.mean(self.xyzr[0][:, :3], axis=0)\n    end_xyz = np.mean(endpoint.xyzr[0][:, :3], axis=0)\n    return np.sqrt(np.sum((start_xyz - end_xyz) ** 2))\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.edge","title":"edge(idx)","text":"

Return a View of the module at the selected synapse edges(s).

Parameters:

Name Type Description Default idx Any

index of the edge to view.

required

Returns:

Type Description View

View of the module at the specified edge index.

Source code in jaxley/modules/base.py
def edge(self, idx: Any) -> View:\n    \"\"\"Return a View of the module at the selected synapse edges(s).\n\n    Args:\n        idx: index of the edge to view.\n\n    Returns:\n        View of the module at the specified edge index.\"\"\"\n    return self._at_edges(\"edge\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.get_all_parameters","title":"get_all_parameters(pstate, voltage_solver)","text":"

Return all parameters (and coupling conductances) needed to simulate.

Runs _compute_axial_conductances() and return every parameter that is needed to solve the ODE. This includes conductances, radiuses, lengths, axial_resistivities, but also coupling conductances.

This is done by first obtaining the current value of every parameter (not only the trainable ones) and then replacing the trainable ones with the value in trainable_params(). This function is run within jx.integrate().

pstate can be obtained by calling params_to_pstate().

.. code-block:: python

params = module.get_parameters() # i.e. [0, 1, 2]\npstate = params_to_pstate(params, module.indices_set_by_trainables)\nmodule.to_jax() # needed for call to module.jaxnodes\n

Parameters:

Name Type Description Default pstate List[Dict]

The state of the trainable parameters. pstate takes the form [{ \u201ckey\u201d: \u201cgNa\u201d, \u201cindices\u201d: jnp.array([0, 1, 2]), \u201cval\u201d: jnp.array([0.1, 0.2, 0.3]) }, \u2026].

required voltage_solver str

The voltage solver that is used. Since jax.sparse and jaxley.xyz require different formats of the axial conductances, this function will default to different building methods.

required

Returns:

Type Description Dict[str, ndarray]

A dictionary of all module parameters.

Source code in jaxley/modules/base.py
@only_allow_module\ndef get_all_parameters(\n    self, pstate: List[Dict], voltage_solver: str\n) -> Dict[str, jnp.ndarray]:\n    # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n    \"\"\"Return all parameters (and coupling conductances) needed to simulate.\n\n    Runs `_compute_axial_conductances()` and return every parameter that is needed\n    to solve the ODE. This includes conductances, radiuses, lengths,\n    axial_resistivities, but also coupling conductances.\n\n    This is done by first obtaining the current value of every parameter (not only\n    the trainable ones) and then replacing the trainable ones with the value\n    in `trainable_params()`. This function is run within `jx.integrate()`.\n\n    pstate can be obtained by calling `params_to_pstate()`.\n\n    .. code-block:: python\n\n        params = module.get_parameters() # i.e. [0, 1, 2]\n        pstate = params_to_pstate(params, module.indices_set_by_trainables)\n        module.to_jax() # needed for call to module.jaxnodes\n\n    Args:\n        pstate: The state of the trainable parameters. pstate takes the form\n            [{\n                \"key\": \"gNa\", \"indices\": jnp.array([0, 1, 2]),\n                \"val\": jnp.array([0.1, 0.2, 0.3])\n            }, ...].\n        voltage_solver: The voltage solver that is used. Since `jax.sparse` and\n            `jaxley.xyz` require different formats of the axial conductances, this\n            function will default to different building methods.\n\n    Returns:\n        A dictionary of all module parameters.\n    \"\"\"\n    params = {}\n    for key in [\"radius\", \"length\", \"axial_resistivity\", \"capacitance\"]:\n        params[key] = self.base.jaxnodes[key]\n\n    for channel in self.base.channels:\n        for channel_params in channel.channel_params:\n            params[channel_params] = self.base.jaxnodes[channel_params]\n\n    for synapse_params in self.base.synapse_param_names:\n        params[synapse_params] = self.base.jaxedges[synapse_params]\n\n    # Override with those parameters set by `.make_trainable()`.\n    for parameter in pstate:\n        key = parameter[\"key\"]\n        inds = parameter[\"indices\"]\n        set_param = parameter[\"val\"]\n\n        # This is needed since SynapseViews worked differently before.\n        # This mimics the old behaviour and tranformes the new indices\n        # to the old indices.\n        # TODO FROM #447: Longterm this should be gotten rid of.\n        # Instead edges should work similar to nodes (would also allow for\n        # param sharing).\n        synapse_inds = self.base.edges.groupby(\"type\").rank()[\"global_edge_index\"]\n        synapse_inds = (synapse_inds.astype(int) - 1).to_numpy()\n        if key in self.base.synapse_param_names:\n            inds = synapse_inds[inds]\n\n        if key in params:  # Only parameters, not initial states.\n            # `inds` is of shape `(num_params, num_comps_per_param)`.\n            # `set_param` is of shape `(num_params,)`\n            # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n            # `.set()` to work. This is done with `[:, None]`.\n            params[key] = params[key].at[inds].set(set_param[:, None])\n\n    # Compute conductance params and add them to the params dictionary.\n    params[\"axial_conductances\"] = self.base._compute_axial_conductances(\n        params=params\n    )\n    return params\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.get_all_states","title":"get_all_states(pstate, all_params, delta_t)","text":"

Get the full initial state of the module from jaxnodes and trainables.

Parameters:

Name Type Description Default pstate List[Dict]

The state of the trainable parameters.

required all_params

All parameters of the module.

required delta_t float

The time step.

required

Returns:

Type Description Dict[str, ndarray]

A dictionary of all states of the module.

Source code in jaxley/modules/base.py
@only_allow_module\ndef get_all_states(\n    self, pstate: List[Dict], all_params, delta_t: float\n) -> Dict[str, jnp.ndarray]:\n    # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n    \"\"\"Get the full initial state of the module from jaxnodes and trainables.\n\n    Args:\n        pstate: The state of the trainable parameters.\n        all_params: All parameters of the module.\n        delta_t: The time step.\n\n    Returns:\n        A dictionary of all states of the module.\n    \"\"\"\n    states = self.base._get_states_from_nodes_and_edges()\n\n    # Override with the initial states set by `.make_trainable()`.\n    for parameter in pstate:\n        key = parameter[\"key\"]\n        inds = parameter[\"indices\"]\n        set_param = parameter[\"val\"]\n        if key in states:  # Only initial states, not parameters.\n            # `inds` is of shape `(num_params, num_comps_per_param)`.\n            # `set_param` is of shape `(num_params,)`\n            # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n            # `.set()` to work. This is done with `[:, None]`.\n            states[key] = states[key].at[inds].set(set_param[:, None])\n\n    # Add to the states the initial current through every channel.\n    states, _ = self.base._channel_currents(\n        states, delta_t, self.channels, self.nodes, all_params\n    )\n\n    # Add to the states the initial current through every synapse.\n    states, _ = self.base._synapse_currents(\n        states, self.synapses, all_params, delta_t, self.edges\n    )\n    return states\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.get_parameters","title":"get_parameters()","text":"

Get all trainable parameters.

The returned parameters should be passed to `jx.integrate(\u2026, params=params).

Returns:

Type Description List[Dict[str, ndarray]]

A list of all trainable parameters in the form of [{\u201cgNa\u201d: jnp.array([0.1, 0.2, 0.3])}, \u2026].

Source code in jaxley/modules/base.py
def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:\n    \"\"\"Get all trainable parameters.\n\n    The returned parameters should be passed to `jx.integrate(..., params=params).\n\n    Returns:\n        A list of all trainable parameters in the form of\n            [{\"gNa\": jnp.array([0.1, 0.2, 0.3])}, ...].\n    \"\"\"\n    return self.trainable_params\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.init_states","title":"init_states(delta_t=0.025)","text":"

Initialize all mechanisms in their steady state.

This considers the voltages and parameters of each compartment.

Parameters:

Name Type Description Default delta_t float

Passed on to channel.init_state().

0.025 Source code in jaxley/modules/base.py
@only_allow_module\ndef init_states(self, delta_t: float = 0.025):\n    # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n    \"\"\"Initialize all mechanisms in their steady state.\n\n    This considers the voltages and parameters of each compartment.\n\n    Args:\n        delta_t: Passed on to `channel.init_state()`.\n    \"\"\"\n    # Update states of the channels.\n    channel_nodes = self.base.nodes\n    states = self.base._get_states_from_nodes_and_edges()\n\n    # We do not use any `pstate` for initializing. In principle, we could change\n    # that by allowing an input `params` and `pstate` to this function.\n    # `voltage_solver` could also be `jax.sparse` here, because both of them\n    # build the channel parameters in the same way.\n    params = self.base.get_all_parameters([], voltage_solver=\"jaxley.thomas\")\n\n    for channel in self.base.channels:\n        name = channel._name\n        channel_indices = channel_nodes.loc[channel_nodes[name]][\n            \"global_comp_index\"\n        ].to_numpy()\n        voltages = channel_nodes.loc[channel_indices, \"v\"].to_numpy()\n\n        channel_param_names = list(channel.channel_params.keys())\n        channel_state_names = list(channel.channel_states.keys())\n        channel_states = query_channel_states_and_params(\n            states, channel_state_names, channel_indices\n        )\n        channel_params = query_channel_states_and_params(\n            params, channel_param_names, channel_indices\n        )\n\n        init_state = channel.init_state(\n            channel_states, voltages, channel_params, delta_t\n        )\n\n        # `init_state` might not return all channel states. Only the ones that are\n        # returned are updated here.\n        for key, val in init_state.items():\n            # Note that we are overriding `self.nodes` here, but `self.nodes` is\n            # not used above to actually compute the current states (so there are\n            # no issues with overriding states).\n            self.nodes.loc[channel_indices, key] = val\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.insert","title":"insert(channel)","text":"

Insert a channel into the module.

Parameters:

Name Type Description Default channel Channel

The channel to insert.

required Source code in jaxley/modules/base.py
def insert(self, channel: Channel):\n    \"\"\"Insert a channel into the module.\n\n    Args:\n        channel: The channel to insert.\"\"\"\n    name = channel._name\n\n    # Channel does not yet exist in the `jx.Module` at all.\n    if name not in [c._name for c in self.base.channels]:\n        self.base.channels.append(channel)\n        self.base.nodes[name] = (\n            False  # Previous columns do not have the new channel.\n        )\n\n    if channel.current_name not in self.base.membrane_current_names:\n        self.base.membrane_current_names.append(channel.current_name)\n\n    # Add a binary column that indicates if a channel is present.\n    self.base.nodes.loc[self._nodes_in_view, name] = True\n\n    # Loop over all new parameters, e.g. gNa, eNa.\n    for key in channel.channel_params:\n        self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_params[key]\n\n    # Loop over all new parameters, e.g. gNa, eNa.\n    for key in channel.channel_states:\n        self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key]\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.loc","title":"loc(at)","text":"

Return a View of the module at the selected branch location(s).

Parameters:

Name Type Description Default at Any

location along the branch.

required

Returns:

Type Description View

View of the module at the specified branch location.

Source code in jaxley/modules/base.py
def loc(self, at: Any) -> View:\n    \"\"\"Return a View of the module at the selected branch location(s).\n\n    Args:\n        at: location along the branch.\n\n    Returns:\n        View of the module at the specified branch location.\"\"\"\n    global_comp_idxs = []\n    for i in self._branches_in_view:\n        ncomp = self.base.ncomp_per_branch[i]\n        comp_locs = np.linspace(0, 1, ncomp)\n        at = comp_locs if is_str_all(at) else self._reformat_index(at, dtype=float)\n        comp_edges = np.linspace(0, 1 + 1e-10, ncomp + 1)\n        idx = np.digitize(at, comp_edges) - 1 + self.base.cumsum_ncomp[i]\n        global_comp_idxs.append(idx)\n    global_comp_idxs = np.concatenate(global_comp_idxs)\n    orig_scope = self._scope\n    # global scope needed to select correct comps, for i.e. branches w. ncomp=[1,2]\n    # loc(0.9)  will correspond to different local branches (0 vs 1).\n    view = self.scope(\"global\").comp(global_comp_idxs).scope(orig_scope)\n    view._current_view = \"loc\"\n    return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.make_trainable","title":"make_trainable(key, init_val=None, verbose=True)","text":"

Make a parameter trainable.

If a parameter is made trainable, it will be returned by get_parameters() and should then be passed to jx.integrate(..., params=params).

Parameters:

Name Type Description Default key str

Name of the parameter to make trainable.

required init_val Optional[Union[float, list]]

Initial value of the parameter. If float, the same value is used for every created parameter. If list, the length of the list has to match the number of created parameters. If None, the current parameter value is used and if parameter sharing is performed that the current parameter value is averaged over all shared parameters.

None verbose bool

Whether to print the number of parameters that are added and the total number of parameters.

True Source code in jaxley/modules/base.py
def make_trainable(\n    self,\n    key: str,\n    init_val: Optional[Union[float, list]] = None,\n    verbose: bool = True,\n):\n    \"\"\"Make a parameter trainable.\n\n    If a parameter is made trainable, it will be returned by `get_parameters()`\n    and should then be passed to `jx.integrate(..., params=params)`.\n\n    Args:\n        key: Name of the parameter to make trainable.\n        init_val: Initial value of the parameter. If `float`, the same value is\n            used for every created parameter. If `list`, the length of the list has\n            to match the number of created parameters. If `None`, the current\n            parameter value is used and if parameter sharing is performed that the\n            current parameter value is averaged over all shared parameters.\n        verbose: Whether to print the number of parameters that are added and the\n            total number of parameters.\n    \"\"\"\n    assert (\n        self.allow_make_trainable\n    ), \"network.cell('all').make_trainable() is not supported. Use a for-loop over cells.\"\n    ncomps_per_branch = (\n        self.base.nodes[\"global_branch_index\"].value_counts().to_numpy()\n    )\n    assert np.all(\n        ncomps_per_branch == ncomps_per_branch[0]\n    ), \"Parameter sharing is not allowed for modules containing branches with different numbers of compartments.\"\n\n    data = self.nodes if key in self.nodes.columns else None\n    data = self.edges if key in self.edges.columns else data\n\n    assert data is not None, f\"Key '{key}' not found in nodes or edges\"\n    not_nan = ~data[key].isna()\n    data = data.loc[not_nan]\n    assert (\n        len(data) > 0\n    ), \"No settable parameters found in the selected compartments.\"\n\n    grouped_view = data.groupby(\"controlled_by_param\")\n    # Because of this `x.index.values` we cannot support `make_trainable()` on\n    # the module level for synapse parameters (but only for `SynapseView`).\n    inds_of_comps = list(\n        grouped_view.apply(lambda x: x.index.values, include_groups=False)\n    )\n    indices_per_param = jnp.stack(inds_of_comps)\n    # Sorted inds are only used to infer the correct starting values.\n    param_vals = jnp.asarray(\n        [data.loc[inds, key].to_numpy() for inds in inds_of_comps]\n    )\n\n    # Set the value which the trainable parameter should take.\n    num_created_parameters = len(indices_per_param)\n    if init_val is not None:\n        if isinstance(init_val, float):\n            new_params = jnp.asarray([init_val] * num_created_parameters)\n        elif isinstance(init_val, list):\n            assert (\n                len(init_val) == num_created_parameters\n            ), f\"len(init_val)={len(init_val)}, but trying to create {num_created_parameters} parameters.\"\n            new_params = jnp.asarray(init_val)\n        else:\n            raise ValueError(\n                f\"init_val must a float, list, or None, but it is a {type(init_val).__name__}.\"\n            )\n    else:\n        new_params = jnp.mean(param_vals, axis=1)\n    self.base.trainable_params.append({key: new_params})\n    self.base.indices_set_by_trainables.append(indices_per_param)\n    self.base.num_trainable_params += num_created_parameters\n    if verbose:\n        print(\n            f\"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}\"\n        )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.move","title":"move(x=0.0, y=0.0, z=0.0, update_nodes=False)","text":"

Move cells or networks by adding to their (x, y, z) coordinates.

This function is used only for visualization. It does not affect the simulation.

Parameters:

Name Type Description Default x float

The amount to move in the x direction in um.

0.0 y float

The amount to move in the y direction in um.

0.0 z float

The amount to move in the z direction in um.

0.0 update_nodes bool

Whether .nodes should be updated or not. Setting this to False largely speeds up moving, especially for big networks, but .nodes or .show will not show the new xyz coordinates.

False Source code in jaxley/modules/base.py
def move(\n    self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = False\n):\n    \"\"\"Move cells or networks by adding to their (x, y, z) coordinates.\n\n    This function is used only for visualization. It does not affect the simulation.\n\n    Args:\n        x: The amount to move in the x direction in um.\n        y: The amount to move in the y direction in um.\n        z: The amount to move in the z direction in um.\n        update_nodes: Whether `.nodes` should be updated or not. Setting this to\n            `False` largely speeds up moving, especially for big networks, but\n            `.nodes` or `.show` will not show the new xyz coordinates.\n    \"\"\"\n    for i in self._branches_in_view:\n        self.base.xyzr[i][:, :3] += np.array([x, y, z])\n    if update_nodes:\n        self.compute_compartment_centers()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.move_to","title":"move_to(x=0.0, y=0.0, z=0.0, update_nodes=False)","text":"

Move cells or networks to a location (x, y, z).

If x, y, and z are floats, then the first compartment of the first branch of the first cell is moved to that float coordinate, and everything else is shifted by the difference between that compartment\u2019s previous coordinate and the new float location.

If x, y, and z are arrays, then they must each have a length equal to the number of cells being moved. Then the first compartment of the first branch of each cell is moved to the specified location.

Parameters:

Name Type Description Default update_nodes bool

Whether .nodes should be updated or not. Setting this to False largely speeds up moving, especially for big networks, but .nodes or .show will not show the new xyz coordinates.

False Source code in jaxley/modules/base.py
def move_to(\n    self,\n    x: Union[float, np.ndarray] = 0.0,\n    y: Union[float, np.ndarray] = 0.0,\n    z: Union[float, np.ndarray] = 0.0,\n    update_nodes: bool = False,\n):\n    \"\"\"Move cells or networks to a location (x, y, z).\n\n    If x, y, and z are floats, then the first compartment of the first branch\n    of the first cell is moved to that float coordinate, and everything else is\n    shifted by the difference between that compartment's previous coordinate and\n    the new float location.\n\n    If x, y, and z are arrays, then they must each have a length equal to the number\n    of cells being moved. Then the first compartment of the first branch of each\n    cell is moved to the specified location.\n\n    Args:\n        update_nodes: Whether `.nodes` should be updated or not. Setting this to\n            `False` largely speeds up moving, especially for big networks, but\n            `.nodes` or `.show` will not show the new xyz coordinates.\n    \"\"\"\n    # Test if any coordinate values are NaN which would greatly affect moving\n    if np.any(np.concatenate(self.xyzr, axis=0)[:, :3] == np.nan):\n        raise ValueError(\n            \"NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values.\"\n        )\n\n    # can only iterate over cells for networks\n    # lambda makes sure that generator can be created multiple times\n    base_is_net = self.base._current_view == \"network\"\n    cells = lambda: (self.cells if base_is_net else [self])\n\n    root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in cells()])\n    root_xyz = root_xyz_cells[0] if isinstance(x, float) else root_xyz_cells\n    move_by = np.array([x, y, z]).T - root_xyz\n\n    if len(move_by.shape) == 1:\n        move_by = np.tile(move_by, (len(self._cells_in_view), 1))\n\n    for cell, offset in zip(cells(), move_by):\n        for idx in cell._branches_in_view:\n            self.base.xyzr[idx][:, :3] += offset\n    if update_nodes:\n        self.compute_compartment_centers()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.rotate","title":"rotate(degrees, rotation_axis='xy', update_nodes=False)","text":"

Rotate jaxley modules clockwise. Used only for visualization.

This function is used only for visualization. It does not affect the simulation.

Parameters:

Name Type Description Default degrees float

How many degrees to rotate the module by.

required rotation_axis str

Either of {xy | xz | yz}.

'xy' Source code in jaxley/modules/base.py
def rotate(\n    self, degrees: float, rotation_axis: str = \"xy\", update_nodes: bool = False\n):\n    \"\"\"Rotate jaxley modules clockwise. Used only for visualization.\n\n    This function is used only for visualization. It does not affect the simulation.\n\n    Args:\n        degrees: How many degrees to rotate the module by.\n        rotation_axis: Either of {`xy` | `xz` | `yz`}.\n    \"\"\"\n    degrees = degrees / 180 * np.pi\n    if rotation_axis == \"xy\":\n        dims = [0, 1]\n    elif rotation_axis == \"xz\":\n        dims = [0, 2]\n    elif rotation_axis == \"yz\":\n        dims = [1, 2]\n    else:\n        raise ValueError\n\n    rotation_matrix = np.asarray(\n        [[np.cos(degrees), np.sin(degrees)], [-np.sin(degrees), np.cos(degrees)]]\n    )\n    for i in self._branches_in_view:\n        rot = np.dot(rotation_matrix, self.base.xyzr[i][:, dims].T).T\n        self.base.xyzr[i][:, dims] = rot\n    if update_nodes:\n        self.compute_compartment_centers()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.scope","title":"scope(scope)","text":"

Return a View of the module with the specified scope.

For example cell.scope(\"global\").branch(2).scope(\"local\").comp(1) will return the 1st compartment of branch 2.

Parameters:

Name Type Description Default scope str

either \u201cglobal\u201d or \u201clocal\u201d.

required

Returns:

Type Description View

View with the specified scope.

Source code in jaxley/modules/base.py
def scope(self, scope: str) -> View:\n    \"\"\"Return a View of the module with the specified scope.\n\n    For example `cell.scope(\"global\").branch(2).scope(\"local\").comp(1)`\n    will return the 1st compartment of branch 2.\n\n    Args:\n        scope: either \"global\" or \"local\".\n\n    Returns:\n        View with the specified scope.\"\"\"\n    view = self.view\n    view.set_scope(scope)\n    return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.select","title":"select(nodes=None, edges=None, sorted=False)","text":"

Return View of the module filtered by specific node or edges indices.

Parameters:

Name Type Description Default nodes ndarray

indices of nodes to view. If None, all nodes are viewed.

None edges ndarray

indices of edges to view. If None, all edges are viewed.

None sorted bool

if True, nodes and edges are sorted.

False

Returns:

Type Description View

View for subset of selected nodes and/or edges.

Source code in jaxley/modules/base.py
def select(\n    self, nodes: np.ndarray = None, edges: np.ndarray = None, sorted: bool = False\n) -> View:\n    \"\"\"Return View of the module filtered by specific node or edges indices.\n\n    Args:\n        nodes: indices of nodes to view. If None, all nodes are viewed.\n        edges: indices of edges to view. If None, all edges are viewed.\n        sorted: if True, nodes and edges are sorted.\n\n    Returns:\n        View for subset of selected nodes and/or edges.\"\"\"\n\n    nodes = self._reformat_index(nodes) if nodes is not None else None\n    nodes = self._nodes_in_view if is_str_all(nodes) else nodes\n    nodes = np.sort(nodes) if sorted else nodes\n\n    edges = self._reformat_index(edges) if edges is not None else None\n    edges = self._edges_in_view if is_str_all(edges) else edges\n    edges = np.sort(edges) if sorted else edges\n\n    view = View(self, nodes, edges)\n    view._set_controlled_by_param(\"filter\")\n    return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.set","title":"set(key, val)","text":"

Set parameter of module (or its view) to a new value.

Note that this function can not be called within jax.jit or jax.grad. Instead, it should be used set the parameters of the module before the simulation. Use .data_set() to set parameters during jax.jit or jax.grad.

Parameters:

Name Type Description Default key str

The name of the parameter to set.

required val Union[float, ndarray]

The value to set the parameter to. If it is jnp.ndarray then it must be of shape (len(num_compartments)).

required Source code in jaxley/modules/base.py
def set(self, key: str, val: Union[float, jnp.ndarray]):\n    \"\"\"Set parameter of module (or its view) to a new value.\n\n    Note that this function can not be called within `jax.jit` or `jax.grad`.\n    Instead, it should be used set the parameters of the module **before** the\n    simulation. Use `.data_set()` to set parameters during `jax.jit` or\n    `jax.grad`.\n\n    Args:\n        key: The name of the parameter to set.\n        val: The value to set the parameter to. If it is `jnp.ndarray` then it\n            must be of shape `(len(num_compartments))`.\n    \"\"\"\n    if key in self.nodes.columns:\n        not_nan = ~self.nodes[key].isna().to_numpy()\n        self.base.nodes.loc[self._nodes_in_view[not_nan], key] = val\n    elif key in self.edges.columns:\n        not_nan = ~self.edges[key].isna().to_numpy()\n        self.base.edges.loc[self._edges_in_view[not_nan], key] = val\n    else:\n        raise KeyError(f\"Key '{key}' not found in nodes or edges\")\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.set_ncomp","title":"set_ncomp(ncomp, min_radius=None)","text":"

Set the number of compartments with which the branch is discretized.

Parameters:

Name Type Description Default ncomp int

The number of compartments that the branch should be discretized into.

required min_radius Optional[float]

Only used if the morphology was read from an SWC file. If passed the radius is capped to be at least this value.

None Source code in jaxley/modules/base.py
def set_ncomp(\n    self,\n    ncomp: int,\n    min_radius: Optional[float] = None,\n):\n    \"\"\"Set the number of compartments with which the branch is discretized.\n\n    Args:\n        ncomp: The number of compartments that the branch should be discretized\n            into.\n        min_radius: Only used if the morphology was read from an SWC file. If passed\n            the radius is capped to be at least this value.\n\n    Raises:\n        - When there are stimuli in any compartment in the module.\n        - When there are recordings in any compartment in the module.\n        - When the channels of the compartments are not the same within the branch\n        that is modified.\n        - When the lengths of the compartments are not the same within the branch\n        that is modified.\n        - Unless the morphology was read from an SWC file, when the radiuses of the\n        compartments are not the same within the branch that is modified.\n    \"\"\"\n    assert len(self.base.externals) == 0, \"No stimuli allowed!\"\n    assert len(self.base.recordings) == 0, \"No recordings allowed!\"\n    assert len(self.base.trainable_params) == 0, \"No trainables allowed!\"\n\n    assert self.base._module_type != \"network\", \"This is not allowed for networks.\"\n    assert not (\n        self.base._module_type == \"cell\"\n        and len(self._branches_in_view) == len(self.base._branches_in_view)\n    ), \"This is not allowed for cells.\"\n\n    # Update all attributes that are affected by compartment structure.\n    view = self.nodes.copy()\n    all_nodes = self.base.nodes\n    start_idx = self.nodes[\"global_comp_index\"].to_numpy()[0]\n    ncomp_per_branch = self.base.ncomp_per_branch\n    channel_names = [c._name for c in self.base.channels]\n    channel_param_names = list(\n        chain(*[c.channel_params for c in self.base.channels])\n    )\n    channel_state_names = list(\n        chain(*[c.channel_states for c in self.base.channels])\n    )\n    radius_generating_fns = self.base._radius_generating_fns\n\n    within_branch_radiuses = view[\"radius\"].to_numpy()\n    compartment_lengths = view[\"length\"].to_numpy()\n    num_previous_ncomp = len(within_branch_radiuses)\n    branch_indices = pd.unique(view[\"global_branch_index\"])\n\n    error_msg = lambda name: (\n        f\"You previously modified the {name} of individual compartments, but \"\n        f\"now you are modifying the number of compartments in this branch. \"\n        f\"This is not allowed. First build the morphology with `set_ncomp()` and \"\n        f\"then modify the radiuses and lengths of compartments.\"\n    )\n\n    if (\n        ~np.all(within_branch_radiuses == within_branch_radiuses[0])\n        and radius_generating_fns is None\n    ):\n        raise ValueError(error_msg(\"radius\"))\n\n    for property_name in [\"length\", \"capacitance\", \"axial_resistivity\"]:\n        compartment_properties = view[property_name].to_numpy()\n        if ~np.all(compartment_properties == compartment_properties[0]):\n            raise ValueError(error_msg(property_name))\n\n    if not (self.nodes[channel_names].var() == 0.0).all():\n        raise ValueError(\n            \"Some channel exists only in some compartments of the branch which you\"\n            \"are trying to modify. This is not allowed. First specify the number\"\n            \"of compartments with `.set_ncomp()` and then insert the channels\"\n            \"accordingly.\"\n        )\n\n    if not (\n        self.nodes[channel_param_names + channel_state_names].var() == 0.0\n    ).all():\n        raise ValueError(\n            \"Some channel has different parameters or states between the \"\n            \"different compartments of the branch which you are trying to modify. \"\n            \"This is not allowed. First specify the number of compartments with \"\n            \"`.set_ncomp()` and then insert the channels accordingly.\"\n        )\n\n    # Add new rows as the average of all rows. Special case for the length is below.\n    average_row = self.nodes.mean(skipna=False)\n    average_row = average_row.to_frame().T\n    view = pd.concat([*[average_row] * ncomp], axis=\"rows\")\n\n    # Set the correct datatype after having performed an average which cast\n    # everything to float.\n    integer_cols = [\"global_cell_index\", \"global_branch_index\", \"global_comp_index\"]\n    view[integer_cols] = view[integer_cols].astype(int)\n\n    # Whether or not a channel exists in a compartment is a boolean.\n    boolean_cols = channel_names\n    view[boolean_cols] = view[boolean_cols].astype(bool)\n\n    # Special treatment for the lengths and radiuses. These are not being set as\n    # the average because we:\n    # 1) Want to maintain the total length of a branch.\n    # 2) Want to use the SWC inferred radius.\n    #\n    # Compute new compartment lengths.\n    comp_lengths = np.sum(compartment_lengths) / ncomp\n    view[\"length\"] = comp_lengths\n\n    # Compute new compartment radiuses.\n    if radius_generating_fns is not None:\n        view[\"radius\"] = build_radiuses_from_xyzr(\n            radius_fns=radius_generating_fns,\n            branch_indices=branch_indices,\n            min_radius=min_radius,\n            ncomp=ncomp,\n        )\n    else:\n        view[\"radius\"] = within_branch_radiuses[0] * np.ones(ncomp)\n\n    # Update `.nodes`.\n    # 1) Delete N rows starting from start_idx\n    number_deleted = num_previous_ncomp\n    all_nodes = all_nodes.drop(index=range(start_idx, start_idx + number_deleted))\n\n    # 2) Insert M new rows at the same location\n    df1 = all_nodes.iloc[:start_idx]  # Rows before the insertion point\n    df2 = all_nodes.iloc[start_idx:]  # Rows after the insertion point\n\n    # 3) Combine the parts: before, new rows, and after\n    all_nodes = pd.concat([df1, view, df2]).reset_index(drop=True)\n\n    # Override `comp_index` to just be a consecutive list.\n    all_nodes[\"global_comp_index\"] = np.arange(len(all_nodes))\n\n    # Update compartment structure arguments.\n    ncomp_per_branch[branch_indices] = ncomp\n    ncomp = int(np.max(ncomp_per_branch))\n    cumsum_ncomp = cumsum_leading_zero(ncomp_per_branch)\n    internal_node_inds = np.arange(cumsum_ncomp[-1])\n\n    self.base.nodes = all_nodes\n    self.base.ncomp_per_branch = ncomp_per_branch\n    self.base.ncomp = ncomp\n    self.base.cumsum_ncomp = cumsum_ncomp\n    self.base._internal_node_inds = internal_node_inds\n\n    # Update the morphology indexing (e.g., `.comp_edges`).\n    self.base._initialize()\n    self.base._init_view()\n    self.base._update_local_indices()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.set_scope","title":"set_scope(scope)","text":"

Toggle between \u201cglobal\u201d or \u201clocal\u201d scope.

Determines if global or local indices are used for viewing the module.

Parameters:

Name Type Description Default scope str

either \u201cglobal\u201d or \u201clocal\u201d.

required Source code in jaxley/modules/base.py
def set_scope(self, scope: str):\n    \"\"\"Toggle between \"global\" or \"local\" scope.\n\n    Determines if global or local indices are used for viewing the module.\n\n    Args:\n        scope: either \"global\" or \"local\".\"\"\"\n    assert scope in [\"global\", \"local\"], \"Invalid scope.\"\n    self._scope = scope\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.show","title":"show(param_names=None, *, indices=True, params=True, states=True, channel_names=None)","text":"

Print detailed information about the Module or a view of it.

Parameters:

Name Type Description Default param_names Optional[Union[str, List[str]]]

The names of the parameters to show. If None, all parameters are shown.

None indices bool

Whether to show the indices of the compartments.

True params bool

Whether to show the parameters of the compartments.

True states bool

Whether to show the states of the compartments.

True channel_names Optional[List[str]]

The names of the channels to show. If None, all channels are shown.

None

Returns:

Type Description DataFrame

A pd.DataFrame with the requested information.

Source code in jaxley/modules/base.py
def show(\n    self,\n    param_names: Optional[Union[str, List[str]]] = None,\n    *,\n    indices: bool = True,\n    params: bool = True,\n    states: bool = True,\n    channel_names: Optional[List[str]] = None,\n) -> pd.DataFrame:\n    \"\"\"Print detailed information about the Module or a view of it.\n\n    Args:\n        param_names: The names of the parameters to show. If `None`, all parameters\n            are shown.\n        indices: Whether to show the indices of the compartments.\n        params: Whether to show the parameters of the compartments.\n        states: Whether to show the states of the compartments.\n        channel_names: The names of the channels to show. If `None`, all channels are\n            shown.\n\n    Returns:\n        A `pd.DataFrame` with the requested information.\n    \"\"\"\n    nodes = self.nodes.copy()  # prevents this from being edited\n\n    cols = []\n    inds = [\"comp_index\", \"branch_index\", \"cell_index\"]\n    scopes = [\"local\", \"global\"]\n    inds = [f\"{s}_{i}\" for i in inds for s in scopes] if indices else []\n    cols += inds\n    cols += [ch._name for ch in self.channels] if channel_names else []\n    cols += (\n        sum([list(ch.channel_params) for ch in self.channels], []) if params else []\n    )\n    cols += (\n        sum([list(ch.channel_states) for ch in self.channels], []) if states else []\n    )\n\n    if not param_names is None:\n        cols = (\n            inds + [c for c in cols if c in param_names]\n            if params\n            else list(param_names)\n        )\n\n    return nodes[cols]\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.step","title":"step(u, delta_t, external_inds, externals, params, solver='bwd_euler', voltage_solver='jaxley.stone')","text":"

One step of solving the Ordinary Differential Equation.

This function is called inside of integrate and increments the state of the module by one time step. Calls _step_channels and _step_synapse to update the states of the channels and synapses using fwd_euler.

Parameters:

Name Type Description Default u Dict[str, ndarray]

The state of the module. voltages = u[\u201cv\u201d]

required delta_t float

The time step.

required external_inds Dict[str, ndarray]

The indices of the external inputs.

required externals Dict[str, ndarray]

The external inputs.

required params Dict[str, ndarray]

The parameters of the module.

required solver str

The solver to use for the voltages. Either of [\u201cbwd_euler\u201d, \u201cfwd_euler\u201d, \u201ccrank_nicolson\u201d].

'bwd_euler' voltage_solver str

The tridiagonal solver used to diagonalize the coefficient matrix of the ODE system. Either of [\u201cjaxley.thomas\u201d, \u201cjaxley.stone\u201d].

'jaxley.stone'

Returns:

Type Description Dict[str, ndarray]

The updated state of the module.

Source code in jaxley/modules/base.py
@only_allow_module\ndef step(\n    self,\n    u: Dict[str, jnp.ndarray],\n    delta_t: float,\n    external_inds: Dict[str, jnp.ndarray],\n    externals: Dict[str, jnp.ndarray],\n    params: Dict[str, jnp.ndarray],\n    solver: str = \"bwd_euler\",\n    voltage_solver: str = \"jaxley.stone\",\n) -> Dict[str, jnp.ndarray]:\n    \"\"\"One step of solving the Ordinary Differential Equation.\n\n    This function is called inside of `integrate` and increments the state of the\n    module by one time step. Calls `_step_channels` and `_step_synapse` to update\n    the states of the channels and synapses using fwd_euler.\n\n    Args:\n        u: The state of the module. voltages = u[\"v\"]\n        delta_t: The time step.\n        external_inds: The indices of the external inputs.\n        externals: The external inputs.\n        params: The parameters of the module.\n        solver: The solver to use for the voltages. Either of [\"bwd_euler\",\n            \"fwd_euler\", \"crank_nicolson\"].\n        voltage_solver: The tridiagonal solver used to diagonalize the\n            coefficient matrix of the ODE system. Either of [\"jaxley.thomas\",\n            \"jaxley.stone\"].\n\n    Returns:\n        The updated state of the module.\n    \"\"\"\n\n    # Extract the voltages\n    voltages = u[\"v\"]\n\n    # Extract the external inputs\n    if \"i\" in externals.keys():\n        i_current = externals[\"i\"]\n        i_inds = external_inds[\"i\"]\n        i_ext = self._get_external_input(\n            voltages, i_inds, i_current, params[\"radius\"], params[\"length\"]\n        )\n    else:\n        i_ext = 0.0\n\n    # Step of the channels.\n    u, (v_terms, const_terms) = self._step_channels(\n        u, delta_t, self.channels, self.nodes, params\n    )\n\n    # Step of the synapse.\n    u, (syn_v_terms, syn_const_terms) = self._step_synapse(\n        u,\n        self.synapses,\n        params,\n        delta_t,\n        self.edges,\n    )\n\n    # Clamp for channels and synapses.\n    for key in externals.keys():\n        if key not in [\"i\", \"v\"]:\n            u[key] = u[key].at[external_inds[key]].set(externals[key])\n\n    # Voltage steps.\n    cm = params[\"capacitance\"]  # Abbreviation.\n\n    # Arguments used by all solvers.\n    solver_kwargs = {\n        \"voltages\": voltages,\n        \"voltage_terms\": (v_terms + syn_v_terms) / cm,\n        \"constant_terms\": (const_terms + i_ext + syn_const_terms) / cm,\n        \"axial_conductances\": params[\"axial_conductances\"],\n        \"internal_node_inds\": self._internal_node_inds,\n    }\n\n    # Add solver specific arguments.\n    if voltage_solver == \"jax.sparse\":\n        solver_kwargs.update(\n            {\n                \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n                \"data_inds\": self._data_inds,\n                \"indices\": self._indices_jax_spsolve,\n                \"indptr\": self._indptr_jax_spsolve,\n                \"n_nodes\": self._n_nodes,\n            }\n        )\n        # Only for `bwd_euler` and `cranck-nicolson`.\n        step_voltage_implicit = step_voltage_implicit_with_jax_spsolve\n    else:\n        # Our custom sparse solver requires a different format of all conductance\n        # values to perform triangulation and backsubstution optimally.\n        #\n        # Currently, the forward Euler solver also uses this format. However,\n        # this is only for historical reasons and we are planning to change this in\n        # the future.\n        solver_kwargs.update(\n            {\n                \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n                \"sources\": np.asarray(self._comp_edges[\"source\"].to_list()),\n                \"types\": np.asarray(self._comp_edges[\"type\"].to_list()),\n                \"ncomp_per_branch\": self.ncomp_per_branch,\n                \"par_inds\": self._par_inds,\n                \"child_inds\": self._child_inds,\n                \"nbranches\": self.total_nbranches,\n                \"solver\": voltage_solver,\n                \"idx\": self._solve_indexer,\n                \"debug_states\": self.debug_states,\n            }\n        )\n        # Only for `bwd_euler` and `cranck-nicolson`.\n        step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve\n\n    if solver == \"bwd_euler\":\n        u[\"v\"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t)\n    elif solver == \"crank_nicolson\":\n        # Crank-Nicolson advances by half a step of backward and half a step of\n        # forward Euler.\n        half_step_delta_t = delta_t / 2\n        half_step_voltages = step_voltage_implicit(\n            **solver_kwargs, delta_t=half_step_delta_t\n        )\n        # The forward Euler step in Crank-Nicolson can be performed easily as\n        # `V_{n+1} = 2 * V_{n+1/2} - V_n`. See also NEURON book Chapter 4.\n        u[\"v\"] = 2 * half_step_voltages - voltages\n    elif solver == \"fwd_euler\":\n        u[\"v\"] = step_voltage_explicit(**solver_kwargs, delta_t=delta_t)\n    else:\n        raise ValueError(\n            f\"You specified `solver={solver}`. The only allowed solvers are \"\n            \"['bwd_euler', 'fwd_euler', 'crank_nicolson'].\"\n        )\n\n    # Clamp for voltages.\n    if \"v\" in externals.keys():\n        u[\"v\"] = u[\"v\"].at[external_inds[\"v\"]].set(externals[\"v\"])\n\n    return u\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.stimulate","title":"stimulate(current=None, verbose=True)","text":"

Insert a stimulus into the compartment.

current must be a 1d array or have batch dimension of size (num_compartments, ) or (1, ). If 1d, the same stimulus is added to all compartments.

This function cannot be run during jax.jit and jax.grad. Because of this, it should only be used for static stimuli (i.e., stimuli that do not depend on the data and that should not be learned). For stimuli that depend on data (or that should be learned), please use data_stimulate().

Parameters:

Name Type Description Default current Optional[ndarray]

Current in nA.

None Source code in jaxley/modules/base.py
def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True):\n    \"\"\"Insert a stimulus into the compartment.\n\n    current must be a 1d array or have batch dimension of size `(num_compartments, )`\n    or `(1, )`. If 1d, the same stimulus is added to all compartments.\n\n    This function cannot be run during `jax.jit` and `jax.grad`. Because of this,\n    it should only be used for static stimuli (i.e., stimuli that do not depend\n    on the data and that should not be learned). For stimuli that depend on data\n    (or that should be learned), please use `data_stimulate()`.\n\n    Args:\n        current: Current in `nA`.\n    \"\"\"\n    self._external_input(\"i\", current, verbose=verbose)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.to_jax","title":"to_jax()","text":"

Move .nodes to .jaxnodes.

Before the actual simulation is run (via jx.integrate), all parameters of the jx.Module are stored in .nodes (a pd.DataFrame). However, for simulation, these parameters have to be moved to be jnp.ndarrays such that they can be processed on GPU/TPU and such that the simulation can be differentiated. .to_jax() copies the .nodes to .jaxnodes.

Source code in jaxley/modules/base.py
@only_allow_module\ndef to_jax(self):\n    # TODO FROM #447: Make this work for View?\n    \"\"\"Move `.nodes` to `.jaxnodes`.\n\n    Before the actual simulation is run (via `jx.integrate`), all parameters of\n    the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for\n    simulation, these parameters have to be moved to be `jnp.ndarrays` such that\n    they can be processed on GPU/TPU and such that the simulation can be\n    differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`.\n    \"\"\"\n    self.base.jaxnodes = {}\n    for key, value in self.base.nodes.to_dict(orient=\"list\").items():\n        inds = jnp.arange(len(value))\n        self.base.jaxnodes[key] = jnp.asarray(value)[inds]\n\n    # `jaxedges` contains only parameters (no indices).\n    # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n    # we allow parameter sharing.\n    self.base.jaxedges = {}\n    edges = self.base.edges.to_dict(orient=\"list\")\n    for i, synapse in enumerate(self.base.synapses):\n        condition = np.asarray(edges[\"type_ind\"]) == i\n        for key in synapse.synapse_params:\n            self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n        for key in synapse.synapse_states:\n            self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.vis","title":"vis(ax=None, color='k', dims=(0, 1), type='line', **kwargs)","text":"

Visualize the module.

Modules can be visualized on one of the cardinal planes (xy, xz, yz) or even in 3D.

Several options are available: - line: All points from the traced morphology (xyzr), are connected with a line plot. - scatter: All traced points, are plotted as scatter points. - comp: Plots the compartmentalized morphology, including radius and shape. (shows the true compartment lengths per default, but this can be changed via the kwargs, for details see jaxley.utils.plot_utils.plot_comps). - morph: Reconstructs the 3D shape of the traced morphology. For details see jaxley.utils.plot_utils.plot_morph. Warning: For 3D plots and morphologies with many traced points this can be very slow.

Parameters:

Name Type Description Default ax Optional[Axes]

An axis into which to plot.

None color str

The color for all branches.

'k' dims Tuple[int]

Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.

(0, 1) type str

The type of plot. One of [\u201cline\u201d, \u201cscatter\u201d, \u201ccomp\u201d, \u201cmorph\u201d].

'line' kwargs

Keyword arguments passed to the plotting function.

{} Source code in jaxley/modules/base.py
def vis(\n    self,\n    ax: Optional[Axes] = None,\n    color: str = \"k\",\n    dims: Tuple[int] = (0, 1),\n    type: str = \"line\",\n    **kwargs,\n) -> Axes:\n    \"\"\"Visualize the module.\n\n    Modules can be visualized on one of the cardinal planes (xy, xz, yz) or\n    even in 3D.\n\n    Several options are available:\n    - `line`: All points from the traced morphology (`xyzr`), are connected\n    with a line plot.\n    - `scatter`: All traced points, are plotted as scatter points.\n    - `comp`: Plots the compartmentalized morphology, including radius\n    and shape. (shows the true compartment lengths per default, but this can\n    be changed via the `kwargs`, for details see\n    `jaxley.utils.plot_utils.plot_comps`).\n    - `morph`: Reconstructs the 3D shape of the traced morphology. For details see\n    `jaxley.utils.plot_utils.plot_morph`. Warning: For 3D plots and morphologies\n    with many traced points this can be very slow.\n\n    Args:\n        ax: An axis into which to plot.\n        color: The color for all branches.\n        dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n            two of them.\n        type: The type of plot. One of [\"line\", \"scatter\", \"comp\", \"morph\"].\n        kwargs: Keyword arguments passed to the plotting function.\n    \"\"\"\n    res = 100 if \"resolution\" not in kwargs else kwargs.pop(\"resolution\")\n    if \"comp\" in type.lower():\n        return plot_comps(\n            self, dims=dims, ax=ax, color=color, resolution=res, **kwargs\n        )\n    if \"morph\" in type.lower():\n        return plot_morph(\n            self, dims=dims, ax=ax, color=color, resolution=res, **kwargs\n        )\n\n    assert not np.any(\n        [np.isnan(xyzr[:, dims]).all() for xyzr in self.xyzr]\n    ), \"No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`.\"\n\n    ax = plot_graph(\n        self.xyzr,\n        dims=dims,\n        color=color,\n        ax=ax,\n        type=type,\n        **kwargs,\n    )\n\n    return ax\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.write_trainables","title":"write_trainables(trainable_params)","text":"

Write the trainables into .nodes and .edges.

This allows to, e.g., visualize trained networks with .vis().

Parameters:

Name Type Description Default trainable_params List[Dict[str, ndarray]]

The trainable parameters returned by get_parameters().

required Source code in jaxley/modules/base.py
def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):\n    \"\"\"Write the trainables into `.nodes` and `.edges`.\n\n    This allows to, e.g., visualize trained networks with `.vis()`.\n\n    Args:\n        trainable_params: The trainable parameters returned by `get_parameters()`.\n    \"\"\"\n    # We do not support views. Why? `jaxedges` does not have any NaN\n    # elements, whereas edges does. Because of this, we already need special\n    # treatment to make this function work, and it would be an even bigger hassle\n    # if we wanted to support this.\n    assert self.__class__.__name__ in [\n        \"Compartment\",\n        \"Branch\",\n        \"Cell\",\n        \"Network\",\n    ], \"Only supports modules.\"\n\n    # We could also implement this without casting the module to jax.\n    # However, I think it allows us to reuse as much code as possible and it avoids\n    # any kind of issues with indexing or parameter sharing (as this is fully\n    # taken care of by `get_all_parameters()`).\n    self.base.to_jax()\n    pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables)\n    all_params = self.base.get_all_parameters(pstate, voltage_solver=\"jaxley.stone\")\n\n    # The value for `delta_t` does not matter here because it is only used to\n    # compute the initial current. However, the initial current cannot be made\n    # trainable and so its value never gets used below.\n    all_states = self.base.get_all_states(pstate, all_params, delta_t=0.025)\n\n    # Loop only over the keys in `pstate` to avoid unnecessary computation.\n    for parameter in pstate:\n        key = parameter[\"key\"]\n        if key in self.base.nodes.columns:\n            vals_to_set = all_params if key in all_params.keys() else all_states\n            self.base.nodes[key] = vals_to_set[key]\n\n    # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n    # we allow parameter sharing.\n    edges = self.base.edges.to_dict(orient=\"list\")\n    for i, synapse in enumerate(self.base.synapses):\n        condition = np.asarray(edges[\"type_ind\"]) == i\n        for key in list(synapse.synapse_params.keys()):\n            self.base.edges.loc[condition, key] = all_params[key]\n        for key in list(synapse.synapse_states.keys()):\n            self.base.edges.loc[condition, key] = all_states[key]\n
"},{"location":"reference/modules/#compartment","title":"Compartment","text":"

Bases: Module

Compartment class.

This class defines a single compartment that can be simulated by itself or connected up into branches. It is the basic building block of a neuron model.

Source code in jaxley/modules/compartment.py
class Compartment(Module):\n    \"\"\"Compartment class.\n\n    This class defines a single compartment that can be simulated by itself or\n    connected up into branches. It is the basic building block of a neuron model.\n    \"\"\"\n\n    compartment_params: Dict = {\n        \"length\": 10.0,  # um\n        \"radius\": 1.0,  # um\n        \"axial_resistivity\": 5_000.0,  # ohm cm\n        \"capacitance\": 1.0,  # uF/cm^2\n    }\n    compartment_states: Dict = {\"v\": -70.0}\n\n    def __init__(self):\n        super().__init__()\n\n        self.ncomp = 1\n        self.ncomp_per_branch = np.asarray([1])\n        self.total_nbranches = 1\n        self.nbranches_per_cell = [1]\n        self._cumsum_nbranches = np.asarray([0, 1])\n        self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n\n        # Setting up the `nodes` for indexing.\n        self.nodes = pd.DataFrame(\n            dict(global_cell_index=[0], global_branch_index=[0], global_comp_index=[0])\n        )\n        self._append_params_and_states(self.compartment_params, self.compartment_states)\n        self._update_local_indices()\n        self._init_view()\n\n        # Synapses.\n        self.branch_edges = pd.DataFrame(\n            dict(parent_branch_index=[], child_branch_index=[])\n        )\n\n        # For morphology indexing.\n        self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n            compute_children_and_parents(self.branch_edges)\n        )\n        self._internal_node_inds = jnp.asarray([0])\n\n        # Initialize the module.\n        self._initialize()\n\n        # Coordinates.\n        self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n\n    def _init_morph_jaxley_spsolve(self):\n        self._solve_indexer = JaxleySolveIndexer(\n            cumsum_ncomp=self.cumsum_ncomp,\n            branchpoint_group_inds=np.asarray([]).astype(int),\n            children_in_level=[],\n            parents_in_level=[],\n            root_inds=np.asarray([0]),\n            remapped_node_indices=self._internal_node_inds,\n        )\n\n    def _init_morph_jax_spsolve(self):\n        \"\"\"Initialize morphology for the jax sparse voltage solver.\n\n        Explanation of `self._comp_eges['type']`:\n        `type == 0`: compartment <--> compartment (within branch)\n        `type == 1`: branchpoint --> parent-compartment\n        `type == 2`: branchpoint --> child-compartment\n        `type == 3`: parent-compartment --> branchpoint\n        `type == 4`: child-compartment --> branchpoint\n        \"\"\"\n        self._comp_edges = pd.DataFrame().from_dict(\n            {\"source\": [], \"sink\": [], \"type\": []}\n        )\n        n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n        self._n_nodes = n_nodes\n        self._data_inds = data_inds\n        self._indices_jax_spsolve = indices\n        self._indptr_jax_spsolve = indptr\n
"},{"location":"reference/modules/#branch","title":"Branch","text":"

Bases: Module

Branch class.

This class defines a single branch that can be simulated by itself or connected to build a cell. A branch is linear segment of several compartments and can be connected to no, one or more other branches at each end to build more intricate cell morphologies.

Source code in jaxley/modules/branch.py
class Branch(Module):\n    \"\"\"Branch class.\n\n    This class defines a single branch that can be simulated by itself or\n    connected to build a cell. A branch is linear segment of several compartments\n    and can be connected to no, one or more other branches at each end to build more\n    intricate cell morphologies.\n    \"\"\"\n\n    branch_params: Dict = {}\n    branch_states: Dict = {}\n\n    @deprecated_kwargs(\"0.6.0\", [\"nseg\"])\n    def __init__(\n        self,\n        compartments: Optional[Union[Compartment, List[Compartment]]] = None,\n        ncomp: Optional[int] = None,\n        nseg: Optional[int] = None,\n    ):\n        \"\"\"\n        Args:\n            compartments: A single compartment or a list of compartments that make up the\n                branch.\n            ncomp: Number of segments to divide the branch into. If `compartments` is an\n                a single compartment, than the compartment is repeated `ncomp` times to\n                create the branch.\n        \"\"\"\n        # Warnings and errors that deal with the change from `nseg` to `ncomp` change\n        # in Jaxley v0.5.0.\n        if ncomp is not None and nseg is not None:\n            raise ValueError(\"You passed `ncomp` and `nseg`. Please pass only `ncomp`.\")\n        if ncomp is None and nseg is not None:\n            ncomp = nseg\n\n        super().__init__()\n        assert (\n            isinstance(compartments, (Compartment, List)) or compartments is None\n        ), \"Only Compartment or List[Compartment] is allowed.\"\n        if isinstance(compartments, Compartment):\n            assert (\n                ncomp is not None\n            ), \"If `compartments` is not a list then you have to set `ncomp`.\"\n        compartments = Compartment() if compartments is None else compartments\n        ncomp = 1 if ncomp is None else ncomp\n\n        if isinstance(compartments, Compartment):\n            compartment_list = [compartments] * ncomp\n        else:\n            compartment_list = compartments\n\n        self.ncomp = len(compartment_list)\n        self.ncomp_per_branch = np.asarray([self.ncomp])\n        self.total_nbranches = 1\n        self.nbranches_per_cell = [1]\n        self._cumsum_nbranches = jnp.asarray([0, 1])\n        self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n\n        # Indexing.\n        self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)\n        self._append_params_and_states(self.branch_params, self.branch_states)\n        self.nodes[\"global_comp_index\"] = np.arange(self.ncomp).tolist()\n        self.nodes[\"global_branch_index\"] = [0] * self.ncomp\n        self.nodes[\"global_cell_index\"] = [0] * self.ncomp\n        self._update_local_indices()\n        self._init_view()\n\n        # Channels.\n        self._gather_channels_from_constituents(compartment_list)\n\n        self.branch_edges = pd.DataFrame(\n            dict(parent_branch_index=[], child_branch_index=[])\n        )\n\n        # For morphology indexing.\n        self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n            compute_children_and_parents(self.branch_edges)\n        )\n        self._internal_node_inds = jnp.arange(self.ncomp)\n\n        self._initialize()\n\n        # Coordinates.\n        self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n\n    def _init_morph_jaxley_spsolve(self):\n        self._solve_indexer = JaxleySolveIndexer(\n            cumsum_ncomp=self.cumsum_ncomp,\n            branchpoint_group_inds=np.asarray([]).astype(int),\n            remapped_node_indices=self._internal_node_inds,\n            children_in_level=[],\n            parents_in_level=[],\n            root_inds=np.asarray([0]),\n        )\n\n    def _init_morph_jax_spsolve(self):\n        \"\"\"Initialize morphology for the jax sparse voltage solver.\n\n        Explanation of `self._comp_eges['type']`:\n        `type == 0`: compartment <--> compartment (within branch)\n        `type == 1`: branchpoint --> parent-compartment\n        `type == 2`: branchpoint --> child-compartment\n        `type == 3`: parent-compartment --> branchpoint\n        `type == 4`: child-compartment --> branchpoint\n        \"\"\"\n        self._comp_edges = pd.DataFrame().from_dict(\n            {\n                \"source\": list(range(self.ncomp - 1)) + list(range(1, self.ncomp)),\n                \"sink\": list(range(1, self.ncomp)) + list(range(self.ncomp - 1)),\n            }\n        )\n        self._comp_edges[\"type\"] = 0\n        n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n        self._n_nodes = n_nodes\n        self._data_inds = data_inds\n        self._indices_jax_spsolve = indices\n        self._indptr_jax_spsolve = indptr\n\n    def __len__(self) -> int:\n        return self.ncomp\n
"},{"location":"reference/modules/#jaxley.modules.branch.Branch.__init__","title":"__init__(compartments=None, ncomp=None, nseg=None)","text":"

Parameters:

Name Type Description Default compartments Optional[Union[Compartment, List[Compartment]]]

A single compartment or a list of compartments that make up the branch.

None ncomp Optional[int]

Number of segments to divide the branch into. If compartments is an a single compartment, than the compartment is repeated ncomp times to create the branch.

None Source code in jaxley/modules/branch.py
@deprecated_kwargs(\"0.6.0\", [\"nseg\"])\ndef __init__(\n    self,\n    compartments: Optional[Union[Compartment, List[Compartment]]] = None,\n    ncomp: Optional[int] = None,\n    nseg: Optional[int] = None,\n):\n    \"\"\"\n    Args:\n        compartments: A single compartment or a list of compartments that make up the\n            branch.\n        ncomp: Number of segments to divide the branch into. If `compartments` is an\n            a single compartment, than the compartment is repeated `ncomp` times to\n            create the branch.\n    \"\"\"\n    # Warnings and errors that deal with the change from `nseg` to `ncomp` change\n    # in Jaxley v0.5.0.\n    if ncomp is not None and nseg is not None:\n        raise ValueError(\"You passed `ncomp` and `nseg`. Please pass only `ncomp`.\")\n    if ncomp is None and nseg is not None:\n        ncomp = nseg\n\n    super().__init__()\n    assert (\n        isinstance(compartments, (Compartment, List)) or compartments is None\n    ), \"Only Compartment or List[Compartment] is allowed.\"\n    if isinstance(compartments, Compartment):\n        assert (\n            ncomp is not None\n        ), \"If `compartments` is not a list then you have to set `ncomp`.\"\n    compartments = Compartment() if compartments is None else compartments\n    ncomp = 1 if ncomp is None else ncomp\n\n    if isinstance(compartments, Compartment):\n        compartment_list = [compartments] * ncomp\n    else:\n        compartment_list = compartments\n\n    self.ncomp = len(compartment_list)\n    self.ncomp_per_branch = np.asarray([self.ncomp])\n    self.total_nbranches = 1\n    self.nbranches_per_cell = [1]\n    self._cumsum_nbranches = jnp.asarray([0, 1])\n    self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n\n    # Indexing.\n    self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)\n    self._append_params_and_states(self.branch_params, self.branch_states)\n    self.nodes[\"global_comp_index\"] = np.arange(self.ncomp).tolist()\n    self.nodes[\"global_branch_index\"] = [0] * self.ncomp\n    self.nodes[\"global_cell_index\"] = [0] * self.ncomp\n    self._update_local_indices()\n    self._init_view()\n\n    # Channels.\n    self._gather_channels_from_constituents(compartment_list)\n\n    self.branch_edges = pd.DataFrame(\n        dict(parent_branch_index=[], child_branch_index=[])\n    )\n\n    # For morphology indexing.\n    self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n        compute_children_and_parents(self.branch_edges)\n    )\n    self._internal_node_inds = jnp.arange(self.ncomp)\n\n    self._initialize()\n\n    # Coordinates.\n    self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n
"},{"location":"reference/modules/#cell","title":"Cell","text":"

Bases: Module

Cell class.

This class defines a single cell that can be simulated by itself or connected with synapses to build a network. A cell is made up of several branches and supports intricate cell morphologies.

Source code in jaxley/modules/cell.py
class Cell(Module):\n    \"\"\"Cell class.\n\n    This class defines a single cell that can be simulated by itself or\n    connected with synapses to build a network. A cell is made up of several branches\n    and supports intricate cell morphologies.\n    \"\"\"\n\n    cell_params: Dict = {}\n    cell_states: Dict = {}\n\n    def __init__(\n        self,\n        branches: Optional[Union[Branch, List[Branch]]] = None,\n        parents: Optional[List[int]] = None,\n        xyzr: Optional[List[np.ndarray]] = None,\n    ):\n        \"\"\"Initialize a cell.\n\n        Args:\n            branches: A single branch or a list of branches that make up the cell.\n                If a single branch is provided, then the branch is repeated `len(parents)`\n                times to create the cell.\n            parents: The parent branch index for each branch. The first branch has no\n                parent and is therefore set to -1.\n            xyzr: For every branch, the x, y, and z coordinates and the radius at the\n                traced coordinates. Note that this is the full tracing (from SWC), not\n                the stick representation coordinates.\n        \"\"\"\n        super().__init__()\n        assert (\n            isinstance(branches, (Branch, List)) or branches is None\n        ), \"Only Branch or List[Branch] is allowed.\"\n        if branches is not None:\n            assert (\n                parents is not None\n            ), \"If `branches` is not a list then you have to set `parents`.\"\n        if isinstance(branches, List):\n            assert len(parents) == len(\n                branches\n            ), \"Ensure equally many parents, i.e. len(branches) == len(parents).\"\n\n        branches = Branch() if branches is None else branches\n        parents = [-1] if parents is None else parents\n\n        if isinstance(branches, Branch):\n            branch_list = [branches for _ in range(len(parents))]\n        else:\n            branch_list = branches\n\n        if xyzr is not None:\n            assert len(xyzr) == len(parents)\n            self.xyzr = xyzr\n        else:\n            # For every branch (`len(parents)`), we have a start and end point (`2`) and\n            # a (x,y,z,r) coordinate for each of them (`4`).\n            # Since `xyzr` is only inspected at `.vis()` and because it depends on the\n            # (potentially learned) length of every compartment, we only populate\n            # self.xyzr at `.vis()`.\n            self.xyzr = [float(\"NaN\") * np.zeros((2, 4)) for _ in range(len(parents))]\n\n        self.total_nbranches = len(branch_list)\n        self.nbranches_per_cell = [len(branch_list)]\n        self.comb_parents = jnp.asarray(parents)\n        self.comb_children = compute_children_indices(self.comb_parents)\n        self._cumsum_nbranches = np.asarray([0, len(branch_list)])\n\n        # Compartment structure. These arguments have to be rebuilt when `.set_ncomp()`\n        # is run.\n        self.ncomp_per_branch = np.asarray([branch.ncomp for branch in branch_list])\n        self.ncomp = int(np.max(self.ncomp_per_branch))\n        self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n        self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n\n        # Build nodes. Has to be changed when `.set_ncomp()` is run.\n        self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True)\n        self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n        self.nodes[\"global_branch_index\"] = np.repeat(\n            np.arange(self.total_nbranches), self.ncomp_per_branch\n        ).tolist()\n        self.nodes[\"global_cell_index\"] = np.repeat(0, self.cumsum_ncomp[-1]).tolist()\n        self._update_local_indices()\n        self._init_view()\n\n        # Appending general parameters (radius, length, r_a, cm) and channel parameters,\n        # as well as the states (v, and channel states).\n        self._append_params_and_states(self.cell_params, self.cell_states)\n\n        # Channels.\n        self._gather_channels_from_constituents(branch_list)\n\n        self.branch_edges = pd.DataFrame(\n            dict(\n                parent_branch_index=self.comb_parents[1:],\n                child_branch_index=np.arange(1, self.total_nbranches),\n            )\n        )\n\n        # For morphology indexing.\n        self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n            compute_children_and_parents(self.branch_edges)\n        )\n\n        self._initialize()\n\n    def _init_morph_jaxley_spsolve(self):\n        \"\"\"Initialize morphology for the custom sparse solver.\n\n        Running this function is only required for custom Jaxley solvers, i.e., for\n        `voltage_solver={'jaxley.stone', 'jaxley.thomas'}`. However, because at\n        `.__init__()` (when the function is run), we do not yet know which solver the\n        user will use. Therefore, we always run this function at `.__init__()`.\n        \"\"\"\n        children_and_parents = compute_morphology_indices_in_levels(\n            len(self._par_inds),\n            self._child_belongs_to_branchpoint,\n            self._par_inds,\n            self._child_inds,\n        )\n        branchpoint_group_inds = build_branchpoint_group_inds(\n            len(self._par_inds),\n            self._child_belongs_to_branchpoint,\n            self.cumsum_ncomp[-1],\n        )\n        parents = self.comb_parents\n        children_inds = children_and_parents[\"children\"]\n        parents_inds = children_and_parents[\"parents\"]\n\n        levels = compute_levels(parents)\n        children_in_level = compute_children_in_level(levels, children_inds)\n        parents_in_level = compute_parents_in_level(\n            levels, self._par_inds, parents_inds\n        )\n        levels_and_ncomp = pd.DataFrame().from_dict(\n            {\n                \"levels\": levels,\n                \"ncomps\": self.ncomp_per_branch,\n            }\n        )\n        levels_and_ncomp[\"max_ncomp_in_level\"] = levels_and_ncomp.groupby(\"levels\")[\n            \"ncomps\"\n        ].transform(\"max\")\n        padded_cumsum_ncomp = cumsum_leading_zero(\n            levels_and_ncomp[\"max_ncomp_in_level\"].to_numpy()\n        )\n\n        # Generate mapping to deal with the masking which allows using the custom\n        # sparse solver to deal with different ncomp per branch.\n        remapped_node_indices = remap_index_to_masked(\n            self._internal_node_inds,\n            self.nodes,\n            padded_cumsum_ncomp,\n            self.ncomp_per_branch,\n        )\n        self._solve_indexer = JaxleySolveIndexer(\n            cumsum_ncomp=padded_cumsum_ncomp,\n            branchpoint_group_inds=branchpoint_group_inds,\n            children_in_level=children_in_level,\n            parents_in_level=parents_in_level,\n            root_inds=np.asarray([0]),\n            remapped_node_indices=remapped_node_indices,\n        )\n\n    def _init_morph_jax_spsolve(self):\n        \"\"\"For morphology indexing with the `jax.sparse` voltage volver.\n\n        Explanation of `self._comp_eges['type']`:\n        `type == 0`: compartment <--> compartment (within branch)\n        `type == 1`: branchpoint --> parent-compartment\n        `type == 2`: branchpoint --> child-compartment\n        `type == 3`: parent-compartment --> branchpoint\n        `type == 4`: child-compartment --> branchpoint\n\n        Running this function is only required for generic sparse solvers, i.e., for\n        `voltage_solver='jax.sparse'`.\n        \"\"\"\n\n        # Edges between compartments within the branches.\n        self._comp_edges = pd.concat(\n            [\n                pd.DataFrame()\n                .from_dict(\n                    {\n                        \"source\": list(range(cumsum_ncomp, ncomp - 1 + cumsum_ncomp))\n                        + list(range(1 + cumsum_ncomp, ncomp + cumsum_ncomp)),\n                        \"sink\": list(range(1 + cumsum_ncomp, ncomp + cumsum_ncomp))\n                        + list(range(cumsum_ncomp, ncomp - 1 + cumsum_ncomp)),\n                    }\n                )\n                .astype(int)\n                for ncomp, cumsum_ncomp in zip(self.ncomp_per_branch, self.cumsum_ncomp)\n            ]\n        )\n        self._comp_edges[\"type\"] = 0\n\n        # Edges from branchpoints to compartments.\n        branchpoint_to_parent_edges = pd.DataFrame().from_dict(\n            {\n                \"source\": np.arange(len(self._par_inds)) + self.cumsum_ncomp[-1],\n                \"sink\": self.cumsum_ncomp[self._par_inds + 1] - 1,\n                \"type\": 1,\n            }\n        )\n        branchpoint_to_child_edges = pd.DataFrame().from_dict(\n            {\n                \"source\": self._child_belongs_to_branchpoint + self.cumsum_ncomp[-1],\n                \"sink\": self.cumsum_ncomp[self._child_inds],\n                \"type\": 2,\n            }\n        )\n        self._comp_edges = pd.concat(\n            [\n                self._comp_edges,\n                branchpoint_to_parent_edges,\n                branchpoint_to_child_edges,\n            ],\n            ignore_index=True,\n        )\n\n        # Edges from compartments to branchpoints.\n        parent_to_branchpoint_edges = branchpoint_to_parent_edges.rename(\n            columns={\"sink\": \"source\", \"source\": \"sink\"}\n        )\n        parent_to_branchpoint_edges[\"type\"] = 3\n        child_to_branchpoint_edges = branchpoint_to_child_edges.rename(\n            columns={\"sink\": \"source\", \"source\": \"sink\"}\n        )\n        child_to_branchpoint_edges[\"type\"] = 4\n\n        self._comp_edges = pd.concat(\n            [\n                self._comp_edges,\n                parent_to_branchpoint_edges,\n                child_to_branchpoint_edges,\n            ],\n            ignore_index=True,\n        )\n\n        n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n        self._n_nodes = n_nodes\n        self._data_inds = data_inds\n        self._indices_jax_spsolve = indices\n        self._indptr_jax_spsolve = indptr\n
"},{"location":"reference/modules/#jaxley.modules.cell.Cell.__init__","title":"__init__(branches=None, parents=None, xyzr=None)","text":"

Initialize a cell.

Parameters:

Name Type Description Default branches Optional[Union[Branch, List[Branch]]]

A single branch or a list of branches that make up the cell. If a single branch is provided, then the branch is repeated len(parents) times to create the cell.

None parents Optional[List[int]]

The parent branch index for each branch. The first branch has no parent and is therefore set to -1.

None xyzr Optional[List[ndarray]]

For every branch, the x, y, and z coordinates and the radius at the traced coordinates. Note that this is the full tracing (from SWC), not the stick representation coordinates.

None Source code in jaxley/modules/cell.py
def __init__(\n    self,\n    branches: Optional[Union[Branch, List[Branch]]] = None,\n    parents: Optional[List[int]] = None,\n    xyzr: Optional[List[np.ndarray]] = None,\n):\n    \"\"\"Initialize a cell.\n\n    Args:\n        branches: A single branch or a list of branches that make up the cell.\n            If a single branch is provided, then the branch is repeated `len(parents)`\n            times to create the cell.\n        parents: The parent branch index for each branch. The first branch has no\n            parent and is therefore set to -1.\n        xyzr: For every branch, the x, y, and z coordinates and the radius at the\n            traced coordinates. Note that this is the full tracing (from SWC), not\n            the stick representation coordinates.\n    \"\"\"\n    super().__init__()\n    assert (\n        isinstance(branches, (Branch, List)) or branches is None\n    ), \"Only Branch or List[Branch] is allowed.\"\n    if branches is not None:\n        assert (\n            parents is not None\n        ), \"If `branches` is not a list then you have to set `parents`.\"\n    if isinstance(branches, List):\n        assert len(parents) == len(\n            branches\n        ), \"Ensure equally many parents, i.e. len(branches) == len(parents).\"\n\n    branches = Branch() if branches is None else branches\n    parents = [-1] if parents is None else parents\n\n    if isinstance(branches, Branch):\n        branch_list = [branches for _ in range(len(parents))]\n    else:\n        branch_list = branches\n\n    if xyzr is not None:\n        assert len(xyzr) == len(parents)\n        self.xyzr = xyzr\n    else:\n        # For every branch (`len(parents)`), we have a start and end point (`2`) and\n        # a (x,y,z,r) coordinate for each of them (`4`).\n        # Since `xyzr` is only inspected at `.vis()` and because it depends on the\n        # (potentially learned) length of every compartment, we only populate\n        # self.xyzr at `.vis()`.\n        self.xyzr = [float(\"NaN\") * np.zeros((2, 4)) for _ in range(len(parents))]\n\n    self.total_nbranches = len(branch_list)\n    self.nbranches_per_cell = [len(branch_list)]\n    self.comb_parents = jnp.asarray(parents)\n    self.comb_children = compute_children_indices(self.comb_parents)\n    self._cumsum_nbranches = np.asarray([0, len(branch_list)])\n\n    # Compartment structure. These arguments have to be rebuilt when `.set_ncomp()`\n    # is run.\n    self.ncomp_per_branch = np.asarray([branch.ncomp for branch in branch_list])\n    self.ncomp = int(np.max(self.ncomp_per_branch))\n    self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n    self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n\n    # Build nodes. Has to be changed when `.set_ncomp()` is run.\n    self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True)\n    self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n    self.nodes[\"global_branch_index\"] = np.repeat(\n        np.arange(self.total_nbranches), self.ncomp_per_branch\n    ).tolist()\n    self.nodes[\"global_cell_index\"] = np.repeat(0, self.cumsum_ncomp[-1]).tolist()\n    self._update_local_indices()\n    self._init_view()\n\n    # Appending general parameters (radius, length, r_a, cm) and channel parameters,\n    # as well as the states (v, and channel states).\n    self._append_params_and_states(self.cell_params, self.cell_states)\n\n    # Channels.\n    self._gather_channels_from_constituents(branch_list)\n\n    self.branch_edges = pd.DataFrame(\n        dict(\n            parent_branch_index=self.comb_parents[1:],\n            child_branch_index=np.arange(1, self.total_nbranches),\n        )\n    )\n\n    # For morphology indexing.\n    self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n        compute_children_and_parents(self.branch_edges)\n    )\n\n    self._initialize()\n
"},{"location":"reference/modules/#network","title":"Network","text":"

Bases: Module

Network class.

This class defines a network of cells that can be connected with synapses.

Source code in jaxley/modules/network.py
class Network(Module):\n    \"\"\"Network class.\n\n    This class defines a network of cells that can be connected with synapses.\n    \"\"\"\n\n    network_params: Dict = {}\n    network_states: Dict = {}\n\n    def __init__(\n        self,\n        cells: List[Cell],\n    ):\n        \"\"\"Initialize network of cells and synapses.\n\n        Args:\n            cells: A list of cells that make up the network.\n        \"\"\"\n        super().__init__()\n        for cell in cells:\n            self.xyzr += deepcopy(cell.xyzr)\n\n        self._cells_list = cells\n        self.ncomp_per_branch = np.concatenate(\n            [cell.ncomp_per_branch for cell in cells]\n        )\n        self.ncomp = int(np.max(self.ncomp_per_branch))\n        self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n        self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n        self._append_params_and_states(self.network_params, self.network_states)\n\n        self.nbranches_per_cell = [cell.total_nbranches for cell in cells]\n        self.total_nbranches = sum(self.nbranches_per_cell)\n        self._cumsum_nbranches = cumsum_leading_zero(self.nbranches_per_cell)\n\n        self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True)\n        self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n        self.nodes[\"global_branch_index\"] = np.repeat(\n            np.arange(self.total_nbranches), self.ncomp_per_branch\n        ).tolist()\n        self.nodes[\"global_cell_index\"] = list(\n            itertools.chain(\n                *[[i] * int(cell.cumsum_ncomp[-1]) for i, cell in enumerate(cells)]\n            )\n        )\n        self._update_local_indices()\n        self._init_view()\n\n        parents = [cell.comb_parents for cell in cells]\n        self.comb_parents = jnp.concatenate(\n            [p.at[1:].add(self._cumsum_nbranches[i]) for i, p in enumerate(parents)]\n        )\n\n        # Two columns: `parent_branch_index` and `child_branch_index`. One row per\n        # branch, apart from those branches which do not have a parent (i.e.\n        # -1 in parents). For every branch, tracks the global index of that branch\n        # (`child_branch_index`) and the global index of its parent\n        # (`parent_branch_index`).\n        self.branch_edges = pd.DataFrame(\n            dict(\n                parent_branch_index=self.comb_parents[self.comb_parents != -1],\n                child_branch_index=np.where(self.comb_parents != -1)[0],\n            )\n        )\n\n        # For morphology indexing of both `jax.sparse` and the custom `jaxley` solvers.\n        self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n            compute_children_and_parents(self.branch_edges)\n        )\n\n        # `nbranchpoints` in each cell == cell._par_inds (because `par_inds` are unique).\n        nbranchpoints = jnp.asarray([len(cell._par_inds) for cell in cells])\n        self._cumsum_nbranchpoints_per_cell = cumsum_leading_zero(nbranchpoints)\n\n        # Channels.\n        self._gather_channels_from_constituents(cells)\n\n        self._initialize()\n        del self._cells_list\n\n    def __repr__(self):\n        return f\"{type(self).__name__} with {len(self.channels)} different channels and {len(self.synapses)} synapses. Use `.nodes` or `.edges` for details.\"\n\n    def _init_morph_jaxley_spsolve(self):\n        branchpoint_group_inds = build_branchpoint_group_inds(\n            len(self._par_inds),\n            self._child_belongs_to_branchpoint,\n            self.cumsum_ncomp[-1],\n        )\n        children_in_level = merge_cells(\n            self._cumsum_nbranches,\n            self._cumsum_nbranchpoints_per_cell,\n            [cell._solve_indexer.children_in_level for cell in self._cells_list],\n            exclude_first=False,\n        )\n        parents_in_level = merge_cells(\n            self._cumsum_nbranches,\n            self._cumsum_nbranchpoints_per_cell,\n            [cell._solve_indexer.parents_in_level for cell in self._cells_list],\n            exclude_first=False,\n        )\n        padded_cumsum_ncomp = cumsum_leading_zero(\n            np.concatenate(\n                [np.diff(cell._solve_indexer.cumsum_ncomp) for cell in self._cells_list]\n            )\n        )\n\n        # Generate mapping to dealing with the masking which allows using the custom\n        # sparse solver to deal with different ncomp per branch.\n        remapped_node_indices = remap_index_to_masked(\n            self._internal_node_inds,\n            self.nodes,\n            padded_cumsum_ncomp,\n            self.ncomp_per_branch,\n        )\n        self._solve_indexer = JaxleySolveIndexer(\n            cumsum_ncomp=padded_cumsum_ncomp,\n            branchpoint_group_inds=branchpoint_group_inds,\n            children_in_level=children_in_level,\n            parents_in_level=parents_in_level,\n            root_inds=self._cumsum_nbranches[:-1],\n            remapped_node_indices=remapped_node_indices,\n        )\n\n    def _init_morph_jax_spsolve(self):\n        \"\"\"Initialize the morphology for networks.\n\n        The reason that this function is a bit involved for a `Network` is that Jaxley\n        considers branchpoint nodes to be at the very end of __all__ nodes (i.e. the\n        branchpoints of the first cell are even after the compartments of the second\n        cell. The reason for this is that, otherwise, `cumsum_ncomp` becomes tricky).\n\n        To achieve this, we first loop over all compartments and append them, and then\n        loop over all branchpoints and append those. The code for building the indices\n        from the `comp_edges` is identical to `jx.Cell`.\n\n        Explanation of `self._comp_eges['type']`:\n        `type == 0`: compartment <--> compartment (within branch)\n        `type == 1`: branchpoint --> parent-compartment\n        `type == 2`: branchpoint --> child-compartment\n        `type == 3`: parent-compartment --> branchpoint\n        `type == 4`: child-compartment --> branchpoint\n        \"\"\"\n        self._cumsum_ncomp_per_cell = cumsum_leading_zero(\n            jnp.asarray([cell.cumsum_ncomp[-1] for cell in self.cells])\n        )\n        self._comp_edges = pd.DataFrame()\n\n        # Add all the internal nodes.\n        for offset, cell in zip(self._cumsum_ncomp_per_cell, self._cells_list):\n            condition = cell._comp_edges[\"type\"].to_numpy() == 0\n            rows = cell._comp_edges[condition]\n            self._comp_edges = pd.concat(\n                [self._comp_edges, [offset, offset, 0] + rows], ignore_index=True\n            )\n\n        # All branchpoint-to-compartment nodes.\n        start_branchpoints = self.cumsum_ncomp[-1]  # Index of the first branchpoint.\n        for offset, offset_branchpoints, cell in zip(\n            self._cumsum_ncomp_per_cell,\n            self._cumsum_nbranchpoints_per_cell,\n            self._cells_list,\n        ):\n            offset_within_cell = cell.cumsum_ncomp[-1]\n            condition = cell._comp_edges[\"type\"].isin([1, 2])\n            rows = cell._comp_edges[condition]\n            self._comp_edges = pd.concat(\n                [\n                    self._comp_edges,\n                    [\n                        start_branchpoints - offset_within_cell + offset_branchpoints,\n                        offset,\n                        0,\n                    ]\n                    + rows,\n                ],\n                ignore_index=True,\n            )\n\n        # All compartment-to-branchpoint nodes.\n        for offset, offset_branchpoints, cell in zip(\n            self._cumsum_ncomp_per_cell,\n            self._cumsum_nbranchpoints_per_cell,\n            self._cells_list,\n        ):\n            offset_within_cell = cell.cumsum_ncomp[-1]\n            condition = cell._comp_edges[\"type\"].isin([3, 4])\n            rows = cell._comp_edges[condition]\n            self._comp_edges = pd.concat(\n                [\n                    self._comp_edges,\n                    [\n                        offset,\n                        start_branchpoints - offset_within_cell + offset_branchpoints,\n                        0,\n                    ]\n                    + rows,\n                ],\n                ignore_index=True,\n            )\n\n        # Convert comp_edges to the index format required for `jax.sparse` solvers.\n        n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n        self._n_nodes = n_nodes\n        self._data_inds = data_inds\n        self._indices_jax_spsolve = indices\n        self._indptr_jax_spsolve = indptr\n\n    def _step_synapse(\n        self,\n        states: Dict,\n        syn_channels: List,\n        params: Dict,\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"Perform one step of the synapses and obtain their currents.\"\"\"\n        states = self._step_synapse_state(states, syn_channels, params, delta_t, edges)\n        states, current_terms = self._synapse_currents(\n            states, syn_channels, params, delta_t, edges\n        )\n        return states, current_terms\n\n    def _step_synapse_state(\n        self,\n        states: Dict,\n        syn_channels: List,\n        params: Dict,\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Dict:\n        voltages = states[\"v\"]\n\n        grouped_syns = edges.groupby(\"type\", sort=False, group_keys=False)\n        pre_syn_inds = grouped_syns[\"pre_global_comp_index\"].apply(list)\n        post_syn_inds = grouped_syns[\"post_global_comp_index\"].apply(list)\n        synapse_names = list(grouped_syns.indices.keys())\n\n        for i, synapse_type in enumerate(syn_channels):\n            assert (\n                synapse_names[i] == synapse_type._name\n            ), \"Mixup in the ordering of synapses. Please create an issue on Github.\"\n            synapse_param_names = list(synapse_type.synapse_params.keys())\n            synapse_state_names = list(synapse_type.synapse_states.keys())\n\n            synapse_params = {}\n            for p in synapse_param_names:\n                synapse_params[p] = params[p]\n            synapse_states = {}\n            for s in synapse_state_names:\n                synapse_states[s] = states[s]\n\n            pre_inds = np.asarray(pre_syn_inds[synapse_names[i]])\n            post_inds = np.asarray(post_syn_inds[synapse_names[i]])\n\n            # State updates.\n            states_updated = synapse_type.update_states(\n                synapse_states,\n                delta_t,\n                voltages[pre_inds],\n                voltages[post_inds],\n                synapse_params,\n            )\n\n            # Rebuild state.\n            for key, val in states_updated.items():\n                states[key] = val\n\n        return states\n\n    def _synapse_currents(\n        self,\n        states: Dict,\n        syn_channels: List,\n        params: Dict,\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]:\n        voltages = states[\"v\"]\n\n        grouped_syns = edges.groupby(\"type\", sort=False, group_keys=False)\n        pre_syn_inds = grouped_syns[\"pre_global_comp_index\"].apply(list)\n        post_syn_inds = grouped_syns[\"post_global_comp_index\"].apply(list)\n        synapse_names = list(grouped_syns.indices.keys())\n\n        syn_voltage_terms = jnp.zeros_like(voltages)\n        syn_constant_terms = jnp.zeros_like(voltages)\n        # Run with two different voltages that are `diff` apart to infer the slope and\n        # offset.\n        diff = 1e-3\n        for i, synapse_type in enumerate(syn_channels):\n            assert (\n                synapse_names[i] == synapse_type._name\n            ), \"Mixup in the ordering of synapses. Please create an issue on Github.\"\n            synapse_param_names = list(synapse_type.synapse_params.keys())\n            synapse_state_names = list(synapse_type.synapse_states.keys())\n\n            synapse_params = {}\n            for p in synapse_param_names:\n                synapse_params[p] = params[p]\n            synapse_states = {}\n            for s in synapse_state_names:\n                synapse_states[s] = states[s]\n\n            # Get pre and post indexes of the current synapse type.\n            pre_inds = np.asarray(pre_syn_inds[synapse_names[i]])\n            post_inds = np.asarray(post_syn_inds[synapse_names[i]])\n\n            # Compute slope and offset of the current through every synapse.\n            pre_v_and_perturbed = jnp.stack(\n                [voltages[pre_inds], voltages[pre_inds] + diff]\n            )\n            post_v_and_perturbed = jnp.stack(\n                [voltages[post_inds], voltages[post_inds] + diff]\n            )\n            synapse_currents = vmap(\n                synapse_type.compute_current, in_axes=(None, 0, 0, None)\n            )(\n                synapse_states,\n                pre_v_and_perturbed,\n                post_v_and_perturbed,\n                synapse_params,\n            )\n            synapse_currents_dist = convert_point_process_to_distributed(\n                synapse_currents,\n                params[\"radius\"][post_inds],\n                params[\"length\"][post_inds],\n            )\n\n            # Split into voltage and constant terms.\n            voltage_term = (synapse_currents_dist[1] - synapse_currents_dist[0]) / diff\n            constant_term = (\n                synapse_currents_dist[0] - voltage_term * voltages[post_inds]\n            )\n\n            # Gather slope and offset for every postsynaptic compartment.\n            gathered_syn_currents = gather_synapes(\n                len(voltages),\n                post_inds,\n                voltage_term,\n                constant_term,\n            )\n            syn_voltage_terms += gathered_syn_currents[0]\n            syn_constant_terms -= gathered_syn_currents[1]\n\n            # Add the synaptic currents through every compartment as state.\n            # `post_syn_currents` is a `jnp.ndarray` of as many elements as there are\n            # compartments in the network.\n            # `[0]` because we only use the non-perturbed voltage.\n            states[f\"i_{synapse_type._name}\"] = synapse_currents[0]\n\n        return states, (syn_voltage_terms, syn_constant_terms)\n\n    def arrange_in_layers(\n        self,\n        layers: List[int],\n        within_layer_offset: float = 500.0,\n        between_layer_offset: float = 1500.0,\n        vertical_layers: bool = False,\n    ):\n        \"\"\"Arrange the cells in the network to form layers.\n\n        Moves the cells in the network to arrange them into layers.\n\n        Args:\n            layers: List of integers specifying the number of cells in each layer.\n            within_layer_offset: Offset between cells within the same layer.\n            between_layer_offset: Offset between layers.\n            vertical_layers: If True, layers are arranged vertically.\n        \"\"\"\n        assert (\n            np.sum(layers) == self.shape[0]\n        ), \"The number of cells in the layers must match the number of cells in the network.\"\n        cells_in_layers = [\n            list(range(sum(layers[:i]), sum(layers[: i + 1])))\n            for i in range(len(layers))\n        ]\n\n        for l, cell_inds in enumerate(cells_in_layers):\n            layer = self.cell(cell_inds)\n            for i, cell in enumerate(layer.cells):\n                if vertical_layers:\n                    x_offset = (i - (len(cell_inds) - 1) / 2) * within_layer_offset\n                    y_offset = (len(layers) - 1 - l) * between_layer_offset\n                else:\n                    x_offset = l * between_layer_offset\n                    y_offset = (i - (len(cell_inds) - 1) / 2) * within_layer_offset\n\n                cell.move_to(x=x_offset, y=y_offset, z=0)\n\n    def vis(\n        self,\n        detail: str = \"full\",\n        ax: Optional[Axes] = None,\n        color: str = \"k\",\n        synapse_color: str = \"b\",\n        dims: Tuple[int] = (0, 1),\n        type: str = \"line\",\n        cell_plot_kwargs: Dict = {},\n        synapse_plot_kwargs: Dict = {},\n    ) -> Axes:\n        \"\"\"Visualize the module.\n\n        Args:\n            detail: Either of [point, full]. `point` visualizes every neuron in the\n                network as a dot.\n                `full` plots the full morphology of every neuron. It requires that\n                `compute_xyz()` has been run and allows for indivual neurons to be\n                moved with `.move()`.\n            color: The color in which cells are plotted. Only takes effect if\n                `detail='full'`.\n            type: Either `line` or `scatter`. Only takes effect if `detail='full'`.\n            synapse_color: The color in which synapses are plotted. Only takes effect if\n                `detail='full'`.\n            dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n                two of them.\n            cell_plot_kwargs: Keyword arguments passed to the plotting function for\n                cell morphologies. Only takes effect for `detail='full'`.\n            synapse_plot_kwargs: Keyword arguments passed to the plotting function for\n                syanpses. Only takes effect for `detail='full'`.\n        \"\"\"\n        xyz0 = self.cell(0).xyzr[0][:, :3]\n        same_xyz = np.all([np.all(xyz0 == cell.xyzr[0][:, :3]) for cell in self.cells])\n        if same_xyz:\n            warn(\n                \"Same coordinates for all cells. Consider using `move`, `move_to` or `arrange_in_layers` to move them.\"\n            )\n\n        if ax is None:\n            fig = plt.figure(figsize=(3, 3))\n            ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n        # detail=\"point\" -> pos taken to be the mean of all traced points on the cell.\n        cell_to_point_xyz = lambda cell: np.mean(np.vstack(cell.xyzr)[:, :3], axis=0)\n\n        dims_np = np.asarray(dims)\n        if detail == \"point\":\n            for cell in self.cells:\n                pos = cell_to_point_xyz(cell)[dims_np]\n                ax.scatter(*pos, color=color, **cell_plot_kwargs)\n        elif detail == \"full\":\n            ax = super().vis(\n                dims=dims, color=color, ax=ax, type=type, **cell_plot_kwargs\n            )\n        else:\n            raise ValueError(\"detail must be in {full, point}.\")\n\n        nodes = self.nodes.set_index(\"global_comp_index\")\n        for i, edge in self.edges.iterrows():\n            prepost_locs = []\n            for prepost in [\"pre\", \"post\"]:\n                loc, comp = edge[[prepost + \"_locs\", prepost + \"_global_comp_index\"]]\n                branch = nodes.loc[comp, \"global_branch_index\"]\n                cell = nodes.loc[comp, \"global_cell_index\"]\n                branch_xyz = self.xyzr[branch]\n\n                xyz_loc = branch_xyz\n                if detail == \"point\":\n                    xyz_loc = cell_to_point_xyz(self.cell(cell))\n                elif len(branch_xyz) == 2:\n                    # If only start and end point of a branch are traced, perform a\n                    # linear interpolation to get the synpase location.\n                    xyz_loc = branch_xyz[0] + (branch_xyz[1] - branch_xyz[0]) * loc\n                else:\n                    # If densely traced, use intermediate trace values for synapse loc.\n                    middle_ind = int((len(branch_xyz) - 1) * loc)\n                    xyz_loc = xyz_loc[middle_ind]\n\n                prepost_locs.append(xyz_loc)\n            prepost_locs = np.stack(prepost_locs).T\n\n            ax.plot(*prepost_locs[dims_np], color=synapse_color, **synapse_plot_kwargs)\n\n        return ax\n\n    def _infer_synapse_type_ind(self, synapse_name):\n        syn_names = self.base.synapse_names\n        is_new_type = False if synapse_name in syn_names else True\n        type_ind = len(syn_names) if is_new_type else syn_names.index(synapse_name)\n        return type_ind, is_new_type\n\n    def _update_synapse_state_names(self, synapse_type):\n        # (Potentially) update variables that track meta information about synapses.\n        self.base.synapse_names.append(synapse_type._name)\n        self.base.synapse_param_names += list(synapse_type.synapse_params.keys())\n        self.base.synapse_state_names += list(synapse_type.synapse_states.keys())\n        self.base.synapses.append(synapse_type)\n\n    def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type):\n        # Add synapse types to the module and infer their unique identifier.\n        synapse_name = synapse_type._name\n        synapse_current_name = f\"i_{synapse_name}\"\n        type_ind, is_new = self._infer_synapse_type_ind(synapse_name)\n        if is_new:  # synapse is not known\n            self._update_synapse_state_names(synapse_type)\n            self.base.synapse_current_names.append(synapse_current_name)\n\n        index = len(self.base.edges)\n        indices = [idx for idx in range(index, index + len(pre_nodes))]\n        global_edge_index = pd.DataFrame({\"global_edge_index\": indices})\n        post_loc = loc_of_index(\n            post_nodes[\"global_comp_index\"].to_numpy(),\n            post_nodes[\"global_branch_index\"].to_numpy(),\n            self.ncomp_per_branch,\n        )\n        pre_loc = loc_of_index(\n            pre_nodes[\"global_comp_index\"].to_numpy(),\n            pre_nodes[\"global_branch_index\"].to_numpy(),\n            self.ncomp_per_branch,\n        )\n\n        # Define new synapses. Each row is one synapse.\n        pre_nodes = pre_nodes[[\"global_comp_index\"]]\n        pre_nodes.columns = [\"pre_global_comp_index\"]\n        post_nodes = post_nodes[[\"global_comp_index\"]]\n        post_nodes.columns = [\"post_global_comp_index\"]\n        new_rows = pd.concat(\n            [\n                global_edge_index,\n                pre_nodes.reset_index(drop=True),\n                post_nodes.reset_index(drop=True),\n            ],\n            axis=1,\n        )\n        new_rows[\"type\"] = synapse_name\n        new_rows[\"type_ind\"] = type_ind\n        new_rows[\"pre_locs\"] = pre_loc\n        new_rows[\"post_locs\"] = post_loc\n        self.base.edges = concat_and_ignore_empty(\n            [self.base.edges, new_rows], ignore_index=True, axis=0\n        )\n        self._add_params_to_edges(synapse_type, indices)\n        self.base.edges[\"controlled_by_param\"] = 0\n        self._edges_in_view = self.edges.index.to_numpy()\n\n    def _add_params_to_edges(self, synapse_type, indices):\n        # Add parameters and states to the `.edges` table.\n        for key, param_val in synapse_type.synapse_params.items():\n            self.base.edges.loc[indices, key] = param_val\n\n        # Update synaptic state array.\n        for key, state_val in synapse_type.synapse_states.items():\n            self.base.edges.loc[indices, key] = state_val\n
"},{"location":"reference/modules/#jaxley.modules.network.Network.__init__","title":"__init__(cells)","text":"

Initialize network of cells and synapses.

Parameters:

Name Type Description Default cells List[Cell]

A list of cells that make up the network.

required Source code in jaxley/modules/network.py
def __init__(\n    self,\n    cells: List[Cell],\n):\n    \"\"\"Initialize network of cells and synapses.\n\n    Args:\n        cells: A list of cells that make up the network.\n    \"\"\"\n    super().__init__()\n    for cell in cells:\n        self.xyzr += deepcopy(cell.xyzr)\n\n    self._cells_list = cells\n    self.ncomp_per_branch = np.concatenate(\n        [cell.ncomp_per_branch for cell in cells]\n    )\n    self.ncomp = int(np.max(self.ncomp_per_branch))\n    self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n    self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n    self._append_params_and_states(self.network_params, self.network_states)\n\n    self.nbranches_per_cell = [cell.total_nbranches for cell in cells]\n    self.total_nbranches = sum(self.nbranches_per_cell)\n    self._cumsum_nbranches = cumsum_leading_zero(self.nbranches_per_cell)\n\n    self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True)\n    self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n    self.nodes[\"global_branch_index\"] = np.repeat(\n        np.arange(self.total_nbranches), self.ncomp_per_branch\n    ).tolist()\n    self.nodes[\"global_cell_index\"] = list(\n        itertools.chain(\n            *[[i] * int(cell.cumsum_ncomp[-1]) for i, cell in enumerate(cells)]\n        )\n    )\n    self._update_local_indices()\n    self._init_view()\n\n    parents = [cell.comb_parents for cell in cells]\n    self.comb_parents = jnp.concatenate(\n        [p.at[1:].add(self._cumsum_nbranches[i]) for i, p in enumerate(parents)]\n    )\n\n    # Two columns: `parent_branch_index` and `child_branch_index`. One row per\n    # branch, apart from those branches which do not have a parent (i.e.\n    # -1 in parents). For every branch, tracks the global index of that branch\n    # (`child_branch_index`) and the global index of its parent\n    # (`parent_branch_index`).\n    self.branch_edges = pd.DataFrame(\n        dict(\n            parent_branch_index=self.comb_parents[self.comb_parents != -1],\n            child_branch_index=np.where(self.comb_parents != -1)[0],\n        )\n    )\n\n    # For morphology indexing of both `jax.sparse` and the custom `jaxley` solvers.\n    self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n        compute_children_and_parents(self.branch_edges)\n    )\n\n    # `nbranchpoints` in each cell == cell._par_inds (because `par_inds` are unique).\n    nbranchpoints = jnp.asarray([len(cell._par_inds) for cell in cells])\n    self._cumsum_nbranchpoints_per_cell = cumsum_leading_zero(nbranchpoints)\n\n    # Channels.\n    self._gather_channels_from_constituents(cells)\n\n    self._initialize()\n    del self._cells_list\n
"},{"location":"reference/modules/#jaxley.modules.network.Network.arrange_in_layers","title":"arrange_in_layers(layers, within_layer_offset=500.0, between_layer_offset=1500.0, vertical_layers=False)","text":"

Arrange the cells in the network to form layers.

Moves the cells in the network to arrange them into layers.

Parameters:

Name Type Description Default layers List[int]

List of integers specifying the number of cells in each layer.

required within_layer_offset float

Offset between cells within the same layer.

500.0 between_layer_offset float

Offset between layers.

1500.0 vertical_layers bool

If True, layers are arranged vertically.

False Source code in jaxley/modules/network.py
def arrange_in_layers(\n    self,\n    layers: List[int],\n    within_layer_offset: float = 500.0,\n    between_layer_offset: float = 1500.0,\n    vertical_layers: bool = False,\n):\n    \"\"\"Arrange the cells in the network to form layers.\n\n    Moves the cells in the network to arrange them into layers.\n\n    Args:\n        layers: List of integers specifying the number of cells in each layer.\n        within_layer_offset: Offset between cells within the same layer.\n        between_layer_offset: Offset between layers.\n        vertical_layers: If True, layers are arranged vertically.\n    \"\"\"\n    assert (\n        np.sum(layers) == self.shape[0]\n    ), \"The number of cells in the layers must match the number of cells in the network.\"\n    cells_in_layers = [\n        list(range(sum(layers[:i]), sum(layers[: i + 1])))\n        for i in range(len(layers))\n    ]\n\n    for l, cell_inds in enumerate(cells_in_layers):\n        layer = self.cell(cell_inds)\n        for i, cell in enumerate(layer.cells):\n            if vertical_layers:\n                x_offset = (i - (len(cell_inds) - 1) / 2) * within_layer_offset\n                y_offset = (len(layers) - 1 - l) * between_layer_offset\n            else:\n                x_offset = l * between_layer_offset\n                y_offset = (i - (len(cell_inds) - 1) / 2) * within_layer_offset\n\n            cell.move_to(x=x_offset, y=y_offset, z=0)\n
"},{"location":"reference/modules/#jaxley.modules.network.Network.vis","title":"vis(detail='full', ax=None, color='k', synapse_color='b', dims=(0, 1), type='line', cell_plot_kwargs={}, synapse_plot_kwargs={})","text":"

Visualize the module.

Parameters:

Name Type Description Default detail str

Either of [point, full]. point visualizes every neuron in the network as a dot. full plots the full morphology of every neuron. It requires that compute_xyz() has been run and allows for indivual neurons to be moved with .move().

'full' color str

The color in which cells are plotted. Only takes effect if detail='full'.

'k' type str

Either line or scatter. Only takes effect if detail='full'.

'line' synapse_color str

The color in which synapses are plotted. Only takes effect if detail='full'.

'b' dims Tuple[int]

Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.

(0, 1) cell_plot_kwargs Dict

Keyword arguments passed to the plotting function for cell morphologies. Only takes effect for detail='full'.

{} synapse_plot_kwargs Dict

Keyword arguments passed to the plotting function for syanpses. Only takes effect for detail='full'.

{} Source code in jaxley/modules/network.py
def vis(\n    self,\n    detail: str = \"full\",\n    ax: Optional[Axes] = None,\n    color: str = \"k\",\n    synapse_color: str = \"b\",\n    dims: Tuple[int] = (0, 1),\n    type: str = \"line\",\n    cell_plot_kwargs: Dict = {},\n    synapse_plot_kwargs: Dict = {},\n) -> Axes:\n    \"\"\"Visualize the module.\n\n    Args:\n        detail: Either of [point, full]. `point` visualizes every neuron in the\n            network as a dot.\n            `full` plots the full morphology of every neuron. It requires that\n            `compute_xyz()` has been run and allows for indivual neurons to be\n            moved with `.move()`.\n        color: The color in which cells are plotted. Only takes effect if\n            `detail='full'`.\n        type: Either `line` or `scatter`. Only takes effect if `detail='full'`.\n        synapse_color: The color in which synapses are plotted. Only takes effect if\n            `detail='full'`.\n        dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n            two of them.\n        cell_plot_kwargs: Keyword arguments passed to the plotting function for\n            cell morphologies. Only takes effect for `detail='full'`.\n        synapse_plot_kwargs: Keyword arguments passed to the plotting function for\n            syanpses. Only takes effect for `detail='full'`.\n    \"\"\"\n    xyz0 = self.cell(0).xyzr[0][:, :3]\n    same_xyz = np.all([np.all(xyz0 == cell.xyzr[0][:, :3]) for cell in self.cells])\n    if same_xyz:\n        warn(\n            \"Same coordinates for all cells. Consider using `move`, `move_to` or `arrange_in_layers` to move them.\"\n        )\n\n    if ax is None:\n        fig = plt.figure(figsize=(3, 3))\n        ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n    # detail=\"point\" -> pos taken to be the mean of all traced points on the cell.\n    cell_to_point_xyz = lambda cell: np.mean(np.vstack(cell.xyzr)[:, :3], axis=0)\n\n    dims_np = np.asarray(dims)\n    if detail == \"point\":\n        for cell in self.cells:\n            pos = cell_to_point_xyz(cell)[dims_np]\n            ax.scatter(*pos, color=color, **cell_plot_kwargs)\n    elif detail == \"full\":\n        ax = super().vis(\n            dims=dims, color=color, ax=ax, type=type, **cell_plot_kwargs\n        )\n    else:\n        raise ValueError(\"detail must be in {full, point}.\")\n\n    nodes = self.nodes.set_index(\"global_comp_index\")\n    for i, edge in self.edges.iterrows():\n        prepost_locs = []\n        for prepost in [\"pre\", \"post\"]:\n            loc, comp = edge[[prepost + \"_locs\", prepost + \"_global_comp_index\"]]\n            branch = nodes.loc[comp, \"global_branch_index\"]\n            cell = nodes.loc[comp, \"global_cell_index\"]\n            branch_xyz = self.xyzr[branch]\n\n            xyz_loc = branch_xyz\n            if detail == \"point\":\n                xyz_loc = cell_to_point_xyz(self.cell(cell))\n            elif len(branch_xyz) == 2:\n                # If only start and end point of a branch are traced, perform a\n                # linear interpolation to get the synpase location.\n                xyz_loc = branch_xyz[0] + (branch_xyz[1] - branch_xyz[0]) * loc\n            else:\n                # If densely traced, use intermediate trace values for synapse loc.\n                middle_ind = int((len(branch_xyz) - 1) * loc)\n                xyz_loc = xyz_loc[middle_ind]\n\n            prepost_locs.append(xyz_loc)\n        prepost_locs = np.stack(prepost_locs).T\n\n        ax.plot(*prepost_locs[dims_np], color=synapse_color, **synapse_plot_kwargs)\n\n    return ax\n
"},{"location":"reference/optimize/","title":"Optimization","text":""},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer","title":"TypeOptimizer","text":"

optax wrapper which allows different argument values for different params.

Source code in jaxley/optimize/optimizer.py
class TypeOptimizer:\n    \"\"\"`optax` wrapper which allows different argument values for different params.\"\"\"\n\n    def __init__(\n        self,\n        optimizer: Callable,\n        optimizer_args: Dict[str, Any],\n        opt_params: List[Dict[str, jnp.ndarray]],\n    ):\n        \"\"\"Create the optimizers.\n\n        This requires access to `opt_params` in order to know how many optimizers\n        should be created. It creates `len(opt_params)` optimizers.\n\n        Example usage:\n        ```\n        lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\n        optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\n        opt_state = optimizer.init(opt_params)\n        ```\n\n        ```\n        optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\n        optimizer = TypeOptimizer(\n            lambda args: optax.sgd(args[0], momentum=args[1]),\n            optimizer_args,\n            opt_params\n        )\n        opt_state = optimizer.init(opt_params)\n        ```\n\n        Args:\n            optimizer: A Callable that takes the learning rate and returns the\n                `optax.optimizer` which should be used.\n            optimizer_args: The arguments for different kinds of parameters.\n                Each item of the dictionary will be passed to the `Callable` passed to\n                `optimizer`.\n            opt_params: The parameters to be optimized. The exact values are not used,\n                only the number of elements in the list and the key of each dict.\n        \"\"\"\n        self.base_optimizer = optimizer\n\n        self.optimizers = []\n        for params in opt_params:\n            names = list(params.keys())\n            assert len(names) == 1, \"Multiple parameters were added at once.\"\n            name = names[0]\n            optimizer = self.base_optimizer(optimizer_args[name])\n            self.optimizers.append({name: optimizer})\n\n    def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:\n        \"\"\"Initialize the optimizers. Equivalent to `optax.optimizers.init()`.\"\"\"\n        opt_states = []\n        for params, optimizer in zip(opt_params, self.optimizers):\n            name = list(optimizer.keys())[0]\n            opt_state = optimizer[name].init(params)\n            opt_states.append(opt_state)\n        return opt_states\n\n    def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:\n        \"\"\"Update the optimizers. Equivalent to `optax.optimizers.update()`.\"\"\"\n        all_updates = []\n        new_opt_states = []\n        for grad, state, opt in zip(gradient, opt_state, self.optimizers):\n            name = list(opt.keys())[0]\n            updates, new_opt_state = opt[name].update(grad, state)\n            all_updates.append(updates)\n            new_opt_states.append(new_opt_state)\n        return all_updates, new_opt_states\n
"},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.__init__","title":"__init__(optimizer, optimizer_args, opt_params)","text":"

Create the optimizers.

This requires access to opt_params in order to know how many optimizers should be created. It creates len(opt_params) optimizers.

Example usage:

lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\noptimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\nopt_state = optimizer.init(opt_params)\n

optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\noptimizer = TypeOptimizer(\n    lambda args: optax.sgd(args[0], momentum=args[1]),\n    optimizer_args,\n    opt_params\n)\nopt_state = optimizer.init(opt_params)\n

Parameters:

Name Type Description Default optimizer Callable

A Callable that takes the learning rate and returns the optax.optimizer which should be used.

required optimizer_args Dict[str, Any]

The arguments for different kinds of parameters. Each item of the dictionary will be passed to the Callable passed to optimizer.

required opt_params List[Dict[str, ndarray]]

The parameters to be optimized. The exact values are not used, only the number of elements in the list and the key of each dict.

required Source code in jaxley/optimize/optimizer.py
def __init__(\n    self,\n    optimizer: Callable,\n    optimizer_args: Dict[str, Any],\n    opt_params: List[Dict[str, jnp.ndarray]],\n):\n    \"\"\"Create the optimizers.\n\n    This requires access to `opt_params` in order to know how many optimizers\n    should be created. It creates `len(opt_params)` optimizers.\n\n    Example usage:\n    ```\n    lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\n    optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\n    opt_state = optimizer.init(opt_params)\n    ```\n\n    ```\n    optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\n    optimizer = TypeOptimizer(\n        lambda args: optax.sgd(args[0], momentum=args[1]),\n        optimizer_args,\n        opt_params\n    )\n    opt_state = optimizer.init(opt_params)\n    ```\n\n    Args:\n        optimizer: A Callable that takes the learning rate and returns the\n            `optax.optimizer` which should be used.\n        optimizer_args: The arguments for different kinds of parameters.\n            Each item of the dictionary will be passed to the `Callable` passed to\n            `optimizer`.\n        opt_params: The parameters to be optimized. The exact values are not used,\n            only the number of elements in the list and the key of each dict.\n    \"\"\"\n    self.base_optimizer = optimizer\n\n    self.optimizers = []\n    for params in opt_params:\n        names = list(params.keys())\n        assert len(names) == 1, \"Multiple parameters were added at once.\"\n        name = names[0]\n        optimizer = self.base_optimizer(optimizer_args[name])\n        self.optimizers.append({name: optimizer})\n
"},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.init","title":"init(opt_params)","text":"

Initialize the optimizers. Equivalent to optax.optimizers.init().

Source code in jaxley/optimize/optimizer.py
def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:\n    \"\"\"Initialize the optimizers. Equivalent to `optax.optimizers.init()`.\"\"\"\n    opt_states = []\n    for params, optimizer in zip(opt_params, self.optimizers):\n        name = list(optimizer.keys())[0]\n        opt_state = optimizer[name].init(params)\n        opt_states.append(opt_state)\n    return opt_states\n
"},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.update","title":"update(gradient, opt_state)","text":"

Update the optimizers. Equivalent to optax.optimizers.update().

Source code in jaxley/optimize/optimizer.py
def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:\n    \"\"\"Update the optimizers. Equivalent to `optax.optimizers.update()`.\"\"\"\n    all_updates = []\n    new_opt_states = []\n    for grad, state, opt in zip(gradient, opt_state, self.optimizers):\n        name = list(opt.keys())[0]\n        updates, new_opt_state = opt[name].update(grad, state)\n        all_updates.append(updates)\n        new_opt_states.append(new_opt_state)\n    return all_updates, new_opt_states\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.AffineTransform","title":"AffineTransform","text":"

Bases: Transform

Source code in jaxley/optimize/transforms.py
class AffineTransform(Transform):\n    def __init__(self, scale: ArrayLike, shift: ArrayLike):\n        \"\"\"This transform rescales and shifts the input.\n\n        Args:\n            scale (ArrayLike): Scaling factor.\n            shift (ArrayLike): Additive shift.\n\n        Raises:\n            ValueError: Scale needs to be larger than 0\n        \"\"\"\n        if jnp.allclose(scale, 0):\n            raise ValueError(\"a cannot be zero, must be invertible\")\n        self.a = scale\n        self.b = shift\n\n    def forward(self, x: ArrayLike) -> Array:\n        return self.a * x + self.b\n\n    def inverse(self, x: ArrayLike) -> Array:\n        return (x - self.b) / self.a\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.AffineTransform.__init__","title":"__init__(scale, shift)","text":"

This transform rescales and shifts the input.

Parameters:

Name Type Description Default scale ArrayLike

Scaling factor.

required shift ArrayLike

Additive shift.

required

Raises:

Type Description ValueError

Scale needs to be larger than 0

Source code in jaxley/optimize/transforms.py
def __init__(self, scale: ArrayLike, shift: ArrayLike):\n    \"\"\"This transform rescales and shifts the input.\n\n    Args:\n        scale (ArrayLike): Scaling factor.\n        shift (ArrayLike): Additive shift.\n\n    Raises:\n        ValueError: Scale needs to be larger than 0\n    \"\"\"\n    if jnp.allclose(scale, 0):\n        raise ValueError(\"a cannot be zero, must be invertible\")\n    self.a = scale\n    self.b = shift\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ChainTransform","title":"ChainTransform","text":"

Bases: Transform

Chaining together multiple transformations

Source code in jaxley/optimize/transforms.py
class ChainTransform(Transform):\n    \"\"\"Chaining together multiple transformations\"\"\"\n\n    def __init__(self, transforms: Sequence[Transform]) -> None:\n        \"\"\"A chain of transformations\n\n        Args:\n            transforms (Sequence[Transform]): Transforms to apply\n        \"\"\"\n        super().__init__()\n        self.transforms = transforms\n\n    def forward(self, x: ArrayLike) -> Array:\n        for transform in self.transforms:\n            x = transform(x)\n        return x\n\n    def inverse(self, y: ArrayLike) -> Array:\n        for transform in reversed(self.transforms):\n            y = transform.inverse(y)\n        return y\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ChainTransform.__init__","title":"__init__(transforms)","text":"

A chain of transformations

Parameters:

Name Type Description Default transforms Sequence[Transform]

Transforms to apply

required Source code in jaxley/optimize/transforms.py
def __init__(self, transforms: Sequence[Transform]) -> None:\n    \"\"\"A chain of transformations\n\n    Args:\n        transforms (Sequence[Transform]): Transforms to apply\n    \"\"\"\n    super().__init__()\n    self.transforms = transforms\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.CustomTransform","title":"CustomTransform","text":"

Bases: Transform

Custom transformation

Source code in jaxley/optimize/transforms.py
class CustomTransform(Transform):\n    \"\"\"Custom transformation\"\"\"\n\n    def __init__(self, forward_fn: Callable, inverse_fn: Callable) -> None:\n        \"\"\"A custom transformation using a user-defined froward and\n        inverse function\n\n        Args:\n            forward_fn (Callable): Forward transformation\n            inverse_fn (Callable): Inverse transformation\n        \"\"\"\n        super().__init__()\n        self.forward_fn = forward_fn\n        self.inverse_fn = inverse_fn\n\n    def forward(self, x: ArrayLike) -> Array:\n        return self.forward_fn(x)\n\n    def inverse(self, y: ArrayLike) -> Array:\n        return self.inverse_fn(y)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.CustomTransform.__init__","title":"__init__(forward_fn, inverse_fn)","text":"

A custom transformation using a user-defined froward and inverse function

Parameters:

Name Type Description Default forward_fn Callable

Forward transformation

required inverse_fn Callable

Inverse transformation

required Source code in jaxley/optimize/transforms.py
def __init__(self, forward_fn: Callable, inverse_fn: Callable) -> None:\n    \"\"\"A custom transformation using a user-defined froward and\n    inverse function\n\n    Args:\n        forward_fn (Callable): Forward transformation\n        inverse_fn (Callable): Inverse transformation\n    \"\"\"\n    super().__init__()\n    self.forward_fn = forward_fn\n    self.inverse_fn = inverse_fn\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.MaskedTransform","title":"MaskedTransform","text":"

Bases: Transform

Source code in jaxley/optimize/transforms.py
class MaskedTransform(Transform):\n    def __init__(self, mask: ArrayLike, transform: Transform) -> None:\n        \"\"\"A masked transformation\n\n        Args:\n            mask (ArrayLike): Which elements to transform\n            transform (Transform): Transformation to apply\n        \"\"\"\n        super().__init__()\n        self.mask = mask\n        self.transform = transform\n\n    def forward(self, x: ArrayLike) -> Array:\n        return jnp.where(self.mask, self.transform.forward(x), x)\n\n    def inverse(self, y: ArrayLike) -> Array:\n        return jnp.where(self.mask, self.transform.inverse(y), y)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.MaskedTransform.__init__","title":"__init__(mask, transform)","text":"

A masked transformation

Parameters:

Name Type Description Default mask ArrayLike

Which elements to transform

required transform Transform

Transformation to apply

required Source code in jaxley/optimize/transforms.py
def __init__(self, mask: ArrayLike, transform: Transform) -> None:\n    \"\"\"A masked transformation\n\n    Args:\n        mask (ArrayLike): Which elements to transform\n        transform (Transform): Transformation to apply\n    \"\"\"\n    super().__init__()\n    self.mask = mask\n    self.transform = transform\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.NegSoftplusTransform","title":"NegSoftplusTransform","text":"

Bases: SoftplusTransform

Negative softplus transformation.

Source code in jaxley/optimize/transforms.py
class NegSoftplusTransform(SoftplusTransform):\n    \"\"\"Negative softplus transformation.\"\"\"\n\n    def __init__(self, upper: ArrayLike) -> None:\n        \"\"\"This transform maps any value bijectively to the interval (-inf, upper].\n\n        Args:\n            upper (ArrayLike): Upper bound of the interval.\n        \"\"\"\n        super().__init__(upper)\n\n    def forward(self, x: ArrayLike) -> Array:\n        return -super().forward(-x)\n\n    def inverse(self, y: ArrayLike) -> Array:\n        return -super().inverse(-y)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.NegSoftplusTransform.__init__","title":"__init__(upper)","text":"

This transform maps any value bijectively to the interval (-inf, upper].

Parameters:

Name Type Description Default upper ArrayLike

Upper bound of the interval.

required Source code in jaxley/optimize/transforms.py
def __init__(self, upper: ArrayLike) -> None:\n    \"\"\"This transform maps any value bijectively to the interval (-inf, upper].\n\n    Args:\n        upper (ArrayLike): Upper bound of the interval.\n    \"\"\"\n    super().__init__(upper)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform","title":"ParamTransform","text":"

Parameter transformation utility.

This class is used to transform parameters usually from an unconstrained space to a constrained space and back (bacause most biophysical parameter are bounded). The user can specify a PyTree of transforms that are applied to the parameters.

Attributes:

Name Type Description tf_dict

A PyTree of transforms for each parameter.

Source code in jaxley/optimize/transforms.py
class ParamTransform:\n    \"\"\"Parameter transformation utility.\n\n    This class is used to transform parameters usually from an unconstrained space to a constrained space\n    and back (bacause most biophysical parameter are bounded). The user can specify a PyTree of transforms\n    that are applied to the parameters.\n\n    Attributes:\n        tf_dict: A PyTree of transforms for each parameter.\n\n    \"\"\"\n\n    def __init__(self, tf_dict: List[Dict[str, Transform]] | Transform) -> None:\n        \"\"\"Creates a new ParamTransform object.\n\n        Args:\n            tf_dict: A PyTree of transforms for each parameter.\n        \"\"\"\n\n        self.tf_dict = tf_dict\n\n    def forward(\n        self, params: List[Dict[str, ArrayLike]] | ArrayLike\n    ) -> Dict[str, Array]:\n        \"\"\"Pushes unconstrained parameters through a tf such that they fit the interval.\n\n        Args:\n            params: A list of dictionaries (or any PyTree) with unconstrained parameters.\n\n        Returns:\n            A list of dictionaries (or any PyTree) with transformed parameters.\n\n        \"\"\"\n\n        return jax.tree_util.tree_map(lambda x, tf: tf.forward(x), params, self.tf_dict)\n\n    def inverse(\n        self, params: List[Dict[str, ArrayLike]] | ArrayLike\n    ) -> Dict[str, Array]:\n        \"\"\"Takes parameters from within the interval and makes them unconstrained.\n\n        Args:\n            params: A list of dictionaries (or any PyTree) with transformed parameters.\n\n        Returns:\n            A list of dictionaries (or any PyTree) with unconstrained parameters.\n        \"\"\"\n\n        return jax.tree_util.tree_map(lambda x, tf: tf.inverse(x), params, self.tf_dict)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.__init__","title":"__init__(tf_dict)","text":"

Creates a new ParamTransform object.

Parameters:

Name Type Description Default tf_dict List[Dict[str, Transform]] | Transform

A PyTree of transforms for each parameter.

required Source code in jaxley/optimize/transforms.py
def __init__(self, tf_dict: List[Dict[str, Transform]] | Transform) -> None:\n    \"\"\"Creates a new ParamTransform object.\n\n    Args:\n        tf_dict: A PyTree of transforms for each parameter.\n    \"\"\"\n\n    self.tf_dict = tf_dict\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.forward","title":"forward(params)","text":"

Pushes unconstrained parameters through a tf such that they fit the interval.

Parameters:

Name Type Description Default params List[Dict[str, ArrayLike]] | ArrayLike

A list of dictionaries (or any PyTree) with unconstrained parameters.

required

Returns:

Type Description Dict[str, Array]

A list of dictionaries (or any PyTree) with transformed parameters.

Source code in jaxley/optimize/transforms.py
def forward(\n    self, params: List[Dict[str, ArrayLike]] | ArrayLike\n) -> Dict[str, Array]:\n    \"\"\"Pushes unconstrained parameters through a tf such that they fit the interval.\n\n    Args:\n        params: A list of dictionaries (or any PyTree) with unconstrained parameters.\n\n    Returns:\n        A list of dictionaries (or any PyTree) with transformed parameters.\n\n    \"\"\"\n\n    return jax.tree_util.tree_map(lambda x, tf: tf.forward(x), params, self.tf_dict)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.inverse","title":"inverse(params)","text":"

Takes parameters from within the interval and makes them unconstrained.

Parameters:

Name Type Description Default params List[Dict[str, ArrayLike]] | ArrayLike

A list of dictionaries (or any PyTree) with transformed parameters.

required

Returns:

Type Description Dict[str, Array]

A list of dictionaries (or any PyTree) with unconstrained parameters.

Source code in jaxley/optimize/transforms.py
def inverse(\n    self, params: List[Dict[str, ArrayLike]] | ArrayLike\n) -> Dict[str, Array]:\n    \"\"\"Takes parameters from within the interval and makes them unconstrained.\n\n    Args:\n        params: A list of dictionaries (or any PyTree) with transformed parameters.\n\n    Returns:\n        A list of dictionaries (or any PyTree) with unconstrained parameters.\n    \"\"\"\n\n    return jax.tree_util.tree_map(lambda x, tf: tf.inverse(x), params, self.tf_dict)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SigmoidTransform","title":"SigmoidTransform","text":"

Bases: Transform

Sigmoid transformation.

Source code in jaxley/optimize/transforms.py
class SigmoidTransform(Transform):\n    \"\"\"Sigmoid transformation.\"\"\"\n\n    def __init__(self, lower: ArrayLike, upper: ArrayLike) -> None:\n        \"\"\"This transform maps any value bijectively to the interval [lower, upper].\n\n        Args:\n            lower (ArrayLike): Lower bound of the interval.\n            upper (ArrayLike): Upper bound of the interval.\n        \"\"\"\n        super().__init__()\n        self.lower = lower\n        self.width = upper - lower\n\n    def forward(self, x: ArrayLike) -> Array:\n        y = 1.0 / (1.0 + save_exp(-x))\n        return self.lower + self.width * y\n\n    def inverse(self, y: ArrayLike) -> Array:\n        x = (y - self.lower) / self.width\n        x = -jnp.log((1.0 / x) - 1.0)\n        return x\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SigmoidTransform.__init__","title":"__init__(lower, upper)","text":"

This transform maps any value bijectively to the interval [lower, upper].

Parameters:

Name Type Description Default lower ArrayLike

Lower bound of the interval.

required upper ArrayLike

Upper bound of the interval.

required Source code in jaxley/optimize/transforms.py
def __init__(self, lower: ArrayLike, upper: ArrayLike) -> None:\n    \"\"\"This transform maps any value bijectively to the interval [lower, upper].\n\n    Args:\n        lower (ArrayLike): Lower bound of the interval.\n        upper (ArrayLike): Upper bound of the interval.\n    \"\"\"\n    super().__init__()\n    self.lower = lower\n    self.width = upper - lower\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SoftplusTransform","title":"SoftplusTransform","text":"

Bases: Transform

Softplus transformation.

Source code in jaxley/optimize/transforms.py
class SoftplusTransform(Transform):\n    \"\"\"Softplus transformation.\"\"\"\n\n    def __init__(self, lower: ArrayLike) -> None:\n        \"\"\"This transform maps any value bijectively to the interval [lower, inf).\n\n        Args:\n            lower (ArrayLike): Lower bound of the interval.\n        \"\"\"\n        super().__init__()\n        self.lower = lower\n\n    def forward(self, x: ArrayLike) -> Array:\n        return jnp.log1p(save_exp(x)) + self.lower\n\n    def inverse(self, y: ArrayLike) -> Array:\n        return jnp.log(save_exp(y - self.lower) - 1.0)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SoftplusTransform.__init__","title":"__init__(lower)","text":"

This transform maps any value bijectively to the interval [lower, inf).

Parameters:

Name Type Description Default lower ArrayLike

Lower bound of the interval.

required Source code in jaxley/optimize/transforms.py
def __init__(self, lower: ArrayLike) -> None:\n    \"\"\"This transform maps any value bijectively to the interval [lower, inf).\n\n    Args:\n        lower (ArrayLike): Lower bound of the interval.\n    \"\"\"\n    super().__init__()\n    self.lower = lower\n
"},{"location":"reference/utils/","title":"Utils","text":""},{"location":"reference/utils/#jaxley.utils.cell_utils.build_radiuses_from_xyzr","title":"build_radiuses_from_xyzr(radius_fns, branch_indices, min_radius, ncomp)","text":"

Return the radiuses of branches given SWC file xyzr.

Returns an array of shape (num_branches, ncomp).

Parameters:

Name Type Description Default radius_fns List[Callable]

Functions which, given compartment locations return the radius.

required branch_indices List[int]

The indices of the branches for which to return the radiuses.

required min_radius Optional[float]

If passed, the radiuses are clipped to be at least as large.

required ncomp int

The number of compartments that every branch is discretized into.

required Source code in jaxley/utils/cell_utils.py
def build_radiuses_from_xyzr(\n    radius_fns: List[Callable],\n    branch_indices: List[int],\n    min_radius: Optional[float],\n    ncomp: int,\n) -> jnp.ndarray:\n    \"\"\"Return the radiuses of branches given SWC file xyzr.\n\n    Returns an array of shape `(num_branches, ncomp)`.\n\n    Args:\n        radius_fns: Functions which, given compartment locations return the radius.\n        branch_indices: The indices of the branches for which to return the radiuses.\n        min_radius: If passed, the radiuses are clipped to be at least as large.\n        ncomp: The number of compartments that every branch is discretized into.\n    \"\"\"\n    # Compartment locations are at the center of the internal nodes.\n    non_split = 1 / ncomp\n    range_ = np.linspace(non_split / 2, 1 - non_split / 2, ncomp)\n\n    # Build radiuses.\n    radiuses = np.asarray([radius_fns[b](range_) for b in branch_indices])\n    radiuses_each = radiuses.ravel(order=\"C\")\n    if min_radius is None:\n        assert np.all(\n            radiuses_each > 0.0\n        ), \"Radius 0.0 in SWC file. Set `read_swc(..., min_radius=...)`.\"\n    else:\n        radiuses_each[radiuses_each < min_radius] = min_radius\n\n    return radiuses_each\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_axial_conductances","title":"compute_axial_conductances(comp_edges, params)","text":"

Given comp_edges, radius, length, r_a, cm, compute the axial conductances.

Note that the resulting axial conductances will already by divided by the capacitance cm.

Source code in jaxley/utils/cell_utils.py
def compute_axial_conductances(\n    comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray]\n) -> jnp.ndarray:\n    \"\"\"Given `comp_edges`, radius, length, r_a, cm, compute the axial conductances.\n\n    Note that the resulting axial conductances will already by divided by the\n    capacitance `cm`.\n    \"\"\"\n    # `Compartment-to-compartment` (c2c) axial coupling conductances.\n    condition = comp_edges[\"type\"].to_numpy() == 0\n    source_comp_inds = np.asarray(comp_edges[condition][\"source\"].to_list())\n    sink_comp_inds = np.asarray(comp_edges[condition][\"sink\"].to_list())\n\n    if len(sink_comp_inds) > 0:\n        conds_c2c = (\n            vmap(compute_coupling_cond, in_axes=(0, 0, 0, 0, 0, 0))(\n                params[\"radius\"][sink_comp_inds],\n                params[\"radius\"][source_comp_inds],\n                params[\"axial_resistivity\"][sink_comp_inds],\n                params[\"axial_resistivity\"][source_comp_inds],\n                params[\"length\"][sink_comp_inds],\n                params[\"length\"][source_comp_inds],\n            )\n            / params[\"capacitance\"][sink_comp_inds]\n        )\n    else:\n        conds_c2c = jnp.asarray([])\n\n    # `branchpoint-to-compartment` (bp2c) axial coupling conductances.\n    condition = comp_edges[\"type\"].isin([1, 2])\n    sink_comp_inds = np.asarray(comp_edges[condition][\"sink\"].to_list())\n\n    if len(sink_comp_inds) > 0:\n        conds_bp2c = (\n            vmap(compute_coupling_cond_branchpoint, in_axes=(0, 0, 0))(\n                params[\"radius\"][sink_comp_inds],\n                params[\"axial_resistivity\"][sink_comp_inds],\n                params[\"length\"][sink_comp_inds],\n            )\n            / params[\"capacitance\"][sink_comp_inds]\n        )\n    else:\n        conds_bp2c = jnp.asarray([])\n\n    # `compartment-to-branchpoint` (c2bp) axial coupling conductances.\n    condition = comp_edges[\"type\"].isin([3, 4])\n    source_comp_inds = np.asarray(comp_edges[condition][\"source\"].to_list())\n\n    if len(source_comp_inds) > 0:\n        conds_c2bp = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n            params[\"radius\"][source_comp_inds],\n            params[\"axial_resistivity\"][source_comp_inds],\n            params[\"length\"][source_comp_inds],\n        )\n        # For numerical stability. These values are very small, but their scale\n        # does not matter.\n        conds_c2bp *= 1_000\n    else:\n        conds_c2bp = jnp.asarray([])\n\n    # All axial coupling conductances.\n    return jnp.concatenate([conds_c2c, conds_bp2c, conds_c2bp])\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_children_and_parents","title":"compute_children_and_parents(branch_edges)","text":"

Build indices used during `._init_morph_custom_spsolve().

Source code in jaxley/utils/cell_utils.py
def compute_children_and_parents(\n    branch_edges: pd.DataFrame,\n) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int]:\n    \"\"\"Build indices used during `._init_morph_custom_spsolve().\"\"\"\n    par_inds = branch_edges[\"parent_branch_index\"].to_numpy()\n    child_inds = branch_edges[\"child_branch_index\"].to_numpy()\n    child_belongs_to_branchpoint = remap_to_consecutive(par_inds)\n    par_inds = np.unique(par_inds)\n    return par_inds, child_inds, child_belongs_to_branchpoint\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_children_indices","title":"compute_children_indices(parents)","text":"

Return all children indices of every branch.

Example:

parents = [-1, 0, 0]\ncompute_children_indices(parents) -> [[1, 2], [], []]\n

Source code in jaxley/utils/cell_utils.py
def compute_children_indices(parents) -> List[jnp.ndarray]:\n    \"\"\"Return all children indices of every branch.\n\n    Example:\n    ```\n    parents = [-1, 0, 0]\n    compute_children_indices(parents) -> [[1, 2], [], []]\n    ```\n    \"\"\"\n    num_branches = len(parents)\n    child_indices = []\n    for b in range(num_branches):\n        child_indices.append(np.where(parents == b)[0])\n    return child_indices\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_coupling_cond","title":"compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2)","text":"

Return the coupling conductance between two compartments.

Equations taken from https://en.wikipedia.org/wiki/Compartmental_neuron_models.

radius: um r_a: ohm cm length_single_compartment: um coupling_conds: S * um / cm / um^2 = S / cm / um -> *10**7 -> mS / cm^2

Source code in jaxley/utils/cell_utils.py
def compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2):\n    \"\"\"Return the coupling conductance between two compartments.\n\n    Equations taken from `https://en.wikipedia.org/wiki/Compartmental_neuron_models`.\n\n    `radius`: um\n    `r_a`: ohm cm\n    `length_single_compartment`: um\n    `coupling_conds`: S * um / cm / um^2 = S / cm / um -> *10**7 -> mS / cm^2\n    \"\"\"\n    # Multiply by 10**7 to convert (S / cm / um) -> (mS / cm^2).\n    return rad1 * rad2**2 / (r_a1 * rad2**2 * l1 + r_a2 * rad1**2 * l2) / l1 * 10**7\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_coupling_cond_branchpoint","title":"compute_coupling_cond_branchpoint(rad, r_a, l)","text":"

Return the coupling conductance between one compartment and a comp with l=0.

From https://en.wikipedia.org/wiki/Compartmental_neuron_models

If one compartment has l=0.0 then the equations simplify.

R_long = \\sum_i r_a * L_i/2 / crosssection_i

with crosssection = pi * r**2

For a single compartment with L>0, this turns into: R_long = r_a * L/2 / crosssection

Then, g_long = crosssection * 2 / L / r_a

Then, the effective conductance is g_long / zylinder_area. So: g = pi * r**2 * 2 / L / r_a / 2 / pi / r / L g = r / r_a / L**2

Source code in jaxley/utils/cell_utils.py
def compute_coupling_cond_branchpoint(rad, r_a, l):\n    r\"\"\"Return the coupling conductance between one compartment and a comp with l=0.\n\n    From https://en.wikipedia.org/wiki/Compartmental_neuron_models\n\n    If one compartment has l=0.0 then the equations simplify.\n\n    R_long = \\sum_i r_a * L_i/2 / crosssection_i\n\n    with crosssection = pi * r**2\n\n    For a single compartment with L>0, this turns into:\n    R_long = r_a * L/2 / crosssection\n\n    Then, g_long = crosssection * 2 / L / r_a\n\n    Then, the effective conductance is g_long / zylinder_area. So:\n    g = pi * r**2 * 2 / L / r_a / 2 / pi / r / L\n    g = r / r_a / L**2\n    \"\"\"\n    return rad / r_a / l**2 * 10**7  # Convert (S / cm / um) -> (mS / cm^2)\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_impact_on_node","title":"compute_impact_on_node(rad, r_a, l)","text":"

Compute the weight with which a compartment influences its node.

In order to satisfy Kirchhoffs current law, the current at a branch point must be proportional to the crosssection of the compartment. We only require proportionality here because the branch point equation reads: g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0

Because R_long = r_a * L/2 / crosssection, we get g_long = crosssection * 2 / L / r_a \\propto rad**2 / L / r_a

This equation can be multiplied by any constant.

Source code in jaxley/utils/cell_utils.py
def compute_impact_on_node(rad, r_a, l):\n    r\"\"\"Compute the weight with which a compartment influences its node.\n\n    In order to satisfy Kirchhoffs current law, the current at a branch point must be\n    proportional to the crosssection of the compartment. We only require proportionality\n    here because the branch point equation reads:\n    `g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0`\n\n    Because R_long = r_a * L/2 / crosssection, we get\n    g_long = crosssection * 2 / L / r_a \\propto rad**2 / L / r_a\n\n    This equation can be multiplied by any constant.\"\"\"\n    return rad**2 / r_a / l\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_morphology_indices_in_levels","title":"compute_morphology_indices_in_levels(num_branchpoints, child_belongs_to_branchpoint, par_inds, child_inds)","text":"

Return (row, col) to build the sparse matrix defining the voltage eqs.

This is run at init, not during runtime.

Source code in jaxley/utils/cell_utils.py
def compute_morphology_indices_in_levels(\n    num_branchpoints,\n    child_belongs_to_branchpoint,\n    par_inds,\n    child_inds,\n):\n    \"\"\"Return (row, col) to build the sparse matrix defining the voltage eqs.\n\n    This is run at `init`, not during runtime.\n    \"\"\"\n    branchpoint_inds_parents = jnp.arange(num_branchpoints)\n    branchpoint_inds_children = child_belongs_to_branchpoint\n    branch_inds_parents = par_inds\n    branch_inds_children = child_inds\n\n    children = jnp.stack([branch_inds_children, branchpoint_inds_children])\n    parents = jnp.stack([branch_inds_parents, branchpoint_inds_parents])\n\n    return {\"children\": children.T, \"parents\": parents.T}\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.convert_point_process_to_distributed","title":"convert_point_process_to_distributed(current, radius, length)","text":"

Convert current point process (nA) to distributed current (uA/cm2).

This function gets called for synapses and for external stimuli.

Parameters:

Name Type Description Default current ndarray

Current in nA.

required radius ndarray

Compartment radius in um.

required length ndarray

Compartment length in um.

required Return

Current in uA/cm2.

Source code in jaxley/utils/cell_utils.py
def convert_point_process_to_distributed(\n    current: jnp.ndarray, radius: jnp.ndarray, length: jnp.ndarray\n) -> jnp.ndarray:\n    \"\"\"Convert current point process (nA) to distributed current (uA/cm2).\n\n    This function gets called for synapses and for external stimuli.\n\n    Args:\n        current: Current in `nA`.\n        radius: Compartment radius in `um`.\n        length: Compartment length in `um`.\n\n    Return:\n        Current in `uA/cm2`.\n    \"\"\"\n    area = 2 * pi * radius * length\n    current /= area  # nA / um^2\n    return current * 100_000  # Convert (nA / um^2) to (uA / cm^2)\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.equal_segments","title":"equal_segments(branch_property, ncomp_per_branch)","text":"

Generates segments where some property is the same in each segment.

Parameters:

Name Type Description Default branch_property list

List of values of the property in each branch. Should have len(branch_property) == num_branches.

required Source code in jaxley/utils/cell_utils.py
def equal_segments(branch_property: list, ncomp_per_branch: int):\n    \"\"\"Generates segments where some property is the same in each segment.\n\n    Args:\n        branch_property: List of values of the property in each branch. Should have\n            `len(branch_property) == num_branches`.\n    \"\"\"\n    assert isinstance(branch_property, list), \"branch_property must be a list.\"\n    return jnp.asarray([branch_property] * ncomp_per_branch).T\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.get_num_neighbours","title":"get_num_neighbours(num_children, ncomp_per_branch, num_branches)","text":"

Number of neighbours of each compartment.

Source code in jaxley/utils/cell_utils.py
def get_num_neighbours(\n    num_children: jnp.ndarray,\n    ncomp_per_branch: int,\n    num_branches: int,\n):\n    \"\"\"\n    Number of neighbours of each compartment.\n    \"\"\"\n    num_neighbours = 2 * jnp.ones((num_branches * ncomp_per_branch))\n    num_neighbours = num_neighbours.at[ncomp_per_branch - 1].set(1.0)\n    num_neighbours = num_neighbours.at[jnp.arange(num_branches) * ncomp_per_branch].set(\n        num_children + 1.0\n    )\n    return num_neighbours\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.group_and_sum","title":"group_and_sum(values_to_sum, inds_to_group_by, num_branchpoints)","text":"

Group values by whether they have the same integer and sum values within group.

This is used to construct the last diagonals at the branch points.

Written by ChatGPT.

Source code in jaxley/utils/cell_utils.py
def group_and_sum(\n    values_to_sum: jnp.ndarray, inds_to_group_by: jnp.ndarray, num_branchpoints: int\n) -> jnp.ndarray:\n    \"\"\"Group values by whether they have the same integer and sum values within group.\n\n    This is used to construct the last diagonals at the branch points.\n\n    Written by ChatGPT.\n    \"\"\"\n    # Initialize an array to hold the sum of each group\n    group_sums = jnp.zeros(num_branchpoints)\n\n    # `.at[inds]` requires that `inds` is not empty, so we need an if-case here.\n    # `len(inds) == 0` is the case for branches and compartments.\n    if num_branchpoints > 0:\n        group_sums = group_sums.at[inds_to_group_by].add(values_to_sum)\n\n    return group_sums\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.interpolate_xyzr","title":"interpolate_xyzr(loc, coords)","text":"

Perform a linear interpolation between xyz-coordinates.

Parameters:

Name Type Description Default loc float

The location in [0,1] along the branch.

required coords ndarray

Array containing the reconstructed xyzr points of the branch.

required Return

Interpolated xyz coordinate at loc, shape `(3,).

Source code in jaxley/utils/cell_utils.py
def interpolate_xyzr(loc: float, coords: np.ndarray):\n    \"\"\"Perform a linear interpolation between xyz-coordinates.\n\n    Args:\n        loc: The location in [0,1] along the branch.\n        coords: Array containing the reconstructed xyzr points of the branch.\n\n    Return:\n        Interpolated xyz coordinate at `loc`, shape `(3,).\n    \"\"\"\n    dl = np.sqrt(np.sum(np.diff(coords[:, :3], axis=0) ** 2, axis=1))\n    pathlens = np.insert(np.cumsum(dl), 0, 0)  # cummulative length of sections\n    norm_pathlens = pathlens / np.maximum(1e-8, pathlens[-1])  # norm lengths to [0,1].\n\n    return v_interp(loc, norm_pathlens, coords)\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.linear_segments","title":"linear_segments(initial_val, endpoint_vals, parents, ncomp_per_branch)","text":"

Generates segments where some property is linearly interpolated.

Parameters:

Name Type Description Default initial_val float

The value at the tip of the soma.

required endpoint_vals list

The value at the endpoints of each branch.

required Source code in jaxley/utils/cell_utils.py
def linear_segments(\n    initial_val: float, endpoint_vals: list, parents: jnp.ndarray, ncomp_per_branch: int\n):\n    \"\"\"Generates segments where some property is linearly interpolated.\n\n    Args:\n        initial_val: The value at the tip of the soma.\n        endpoint_vals: The value at the endpoints of each branch.\n    \"\"\"\n    branch_property = endpoint_vals + [initial_val]\n    num_branches = len(parents)\n    # Compute radiuses by linear interpolation.\n    endpoint_radiuses = jnp.asarray(branch_property)\n\n    def compute_rad(branch_ind, loc):\n        start = endpoint_radiuses[parents[branch_ind]]\n        end = endpoint_radiuses[branch_ind]\n        return (end - start) * loc + start\n\n    branch_inds_of_each_comp = jnp.tile(jnp.arange(num_branches), ncomp_per_branch)\n    locs_of_each_comp = jnp.linspace(1, 0, ncomp_per_branch).repeat(num_branches)\n    rad_of_each_comp = compute_rad(branch_inds_of_each_comp, locs_of_each_comp)\n\n    return jnp.reshape(rad_of_each_comp, (ncomp_per_branch, num_branches)).T\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.loc_of_index","title":"loc_of_index(global_comp_index, global_branch_index, ncomp_per_branch)","text":"

Return location corresponding to global compartment index.

Source code in jaxley/utils/cell_utils.py
def loc_of_index(global_comp_index, global_branch_index, ncomp_per_branch):\n    \"\"\"Return location corresponding to global compartment index.\"\"\"\n    cumsum_ncomp = cumsum_leading_zero(ncomp_per_branch)\n    index = global_comp_index - cumsum_ncomp[global_branch_index]\n    ncomp = ncomp_per_branch[global_branch_index]\n    return (0.5 + index) / ncomp\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.local_index_of_loc","title":"local_index_of_loc(loc, global_branch_ind, ncomp_per_branch)","text":"

Returns the local index of a comp given a loc [0, 1] and the index of a branch.

This is used because we specify locations such as synapses as a value between 0 and 1. We have to convert this onto a discrete segment here.

Parameters:

Name Type Description Default branch_ind

Index of the branch.

required loc float

Location (in [0, 1]) along that branch.

required ncomp_per_branch int

Number of segments of each branch.

required

Returns:

Type Description int

The local index of the compartment.

Source code in jaxley/utils/cell_utils.py
def local_index_of_loc(\n    loc: float, global_branch_ind: int, ncomp_per_branch: int\n) -> int:\n    \"\"\"Returns the local index of a comp given a loc [0, 1] and the index of a branch.\n\n    This is used because we specify locations such as synapses as a value between 0 and\n    1. We have to convert this onto a discrete segment here.\n\n    Args:\n        branch_ind: Index of the branch.\n        loc: Location (in [0, 1]) along that branch.\n        ncomp_per_branch: Number of segments of each branch.\n\n    Returns:\n        The local index of the compartment.\n    \"\"\"\n    ncomp = ncomp_per_branch[global_branch_ind]  # only for convenience.\n    possible_locs = np.linspace(0.5 / ncomp, 1 - 0.5 / ncomp, ncomp)\n    ind_along_branch = np.argmin(np.abs(possible_locs - loc))\n    return ind_along_branch\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.merge_cells","title":"merge_cells(cumsum_num_branches, cumsum_num_branchpoints, arrs, exclude_first=True)","text":"

Build full list of which branches are solved in which iteration.

From the branching pattern of single cells, this \u201cmerges\u201d them into a single ordering of branches.

Parameters:

Name Type Description Default cumsum_num_branches List[int]

cumulative number of branches. E.g., for three cells with 10, 15, and 5 branches respectively, this will should be a list containing [0, 10, 25, 30].

required arrs List[List[ndarray]]

A list of a list of arrays that should be merged.

required exclude_first bool

If True, the first element of each list in arrs will remain unchanged. Useful if a -1 (which indicates \u201cno parent\u201d) entry should not be changed.

True

Returns:

Type Description ndarray

A list of arrays which contain the branch indices that are computed at each

ndarray

level (i.e., iteration).

Source code in jaxley/utils/cell_utils.py
def merge_cells(\n    cumsum_num_branches: List[int],\n    cumsum_num_branchpoints: List[int],\n    arrs: List[List[np.ndarray]],\n    exclude_first: bool = True,\n) -> np.ndarray:\n    \"\"\"\n    Build full list of which branches are solved in which iteration.\n\n    From the branching pattern of single cells, this \"merges\" them into a single\n    ordering of branches.\n\n    Args:\n        cumsum_num_branches: cumulative number of branches. E.g., for three cells with\n            10, 15, and 5 branches respectively, this will should be a list containing\n            `[0, 10, 25, 30]`.\n        arrs: A list of a list of arrays that should be merged.\n        exclude_first: If `True`, the first element of each list in `arrs` will remain\n            unchanged. Useful if a `-1` (which indicates \"no parent\") entry should not\n            be changed.\n\n    Returns:\n        A list of arrays which contain the branch indices that are computed at each\n        level (i.e., iteration).\n    \"\"\"\n    ps = []\n    for i, att in enumerate(arrs):\n        p = att\n        if exclude_first:\n            raise NotImplementedError\n            p = [p[0]] + [p_in_level + cumsum_num_branches[i] for p_in_level in p[1:]]\n        else:\n            p = [\n                p_in_level\n                + np.asarray([cumsum_num_branches[i], cumsum_num_branchpoints[i]])\n                for p_in_level in p\n            ]\n        ps.append(p)\n\n    max_len = max([len(att) for att in arrs])\n    combined_parents_in_level = []\n    for i in range(max_len):\n        current_ps = []\n        for p in ps:\n            if len(p) > i:\n                current_ps.append(p[i])\n        combined_parents_in_level.append(np.concatenate(current_ps))\n\n    return combined_parents_in_level\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.params_to_pstate","title":"params_to_pstate(params, indices_set_by_trainables)","text":"

Make outputs get_parameters() conform with outputs of .data_set().

make_trainable() followed by params=get_parameters() does not return indices because these indices would also be differentiated by jax.grad (as soon as the params are passed to def simulate(params). Therefore, in jx.integrate, we run the function to add indices to the dict. The outputs of params_to_pstate are of the same shape as the outputs of .data_set().

Source code in jaxley/utils/cell_utils.py
def params_to_pstate(\n    params: List[Dict[str, jnp.ndarray]],\n    indices_set_by_trainables: List[jnp.ndarray],\n):\n    \"\"\"Make outputs `get_parameters()` conform with outputs of `.data_set()`.\n\n    `make_trainable()` followed by `params=get_parameters()` does not return indices\n    because these indices would also be differentiated by `jax.grad` (as soon as\n    the `params` are passed to `def simulate(params)`. Therefore, in `jx.integrate`,\n    we run the function to add indices to the dict. The outputs of `params_to_pstate`\n    are of the same shape as the outputs of `.data_set()`.\"\"\"\n    return [\n        {\"key\": list(p.keys())[0], \"val\": list(p.values())[0], \"indices\": i}\n        for p, i in zip(params, indices_set_by_trainables)\n    ]\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.query_channel_states_and_params","title":"query_channel_states_and_params(d, keys, idcs)","text":"

Get dict with subset of keys and values from d.

This is used to restrict a dict where every item contains all states to only the ones that are relevant for the channel. E.g.

states = {'eCa': Array([ 0., 0., nan]}

will be states = {'eCa': Array([ 0., 0.]}

Only loops over necessary keys, as opposed to looping over d.items().

Source code in jaxley/utils/cell_utils.py
def query_channel_states_and_params(d, keys, idcs):\n    \"\"\"Get dict with subset of keys and values from d.\n\n    This is used to restrict a dict where every item contains __all__ states to only\n    the ones that are relevant for the channel. E.g.\n\n    ```states = {'eCa': Array([ 0.,  0., nan]}```\n\n    will be\n    ```states = {'eCa': Array([ 0.,  0.]}```\n\n    Only loops over necessary keys, as opposed to looping over `d.items()`.\"\"\"\n    return dict(zip(keys, (v[idcs] for v in map(d.get, keys))))\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.remap_to_consecutive","title":"remap_to_consecutive(arr)","text":"

Maps an array of integers to an array of consecutive integers.

E.g. [0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]

Source code in jaxley/utils/cell_utils.py
def remap_to_consecutive(arr):\n    \"\"\"Maps an array of integers to an array of consecutive integers.\n\n    E.g. `[0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]`\n    \"\"\"\n    _, inverse_indices = jnp.unique(arr, return_inverse=True)\n    return inverse_indices\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.compute_rotation_matrix","title":"compute_rotation_matrix(axis, angle)","text":"

Return the rotation matrix associated with counterclockwise rotation about the given axis by the given angle.

Can be used to rotate a coordinate vector by multiplying it with the rotation matrix.

Parameters:

Name Type Description Default axis ndarray

The axis of rotation.

required angle float

The angle of rotation in radians.

required

Returns:

Type Description ndarray

A 3x3 rotation matrix.

Source code in jaxley/utils/plot_utils.py
def compute_rotation_matrix(axis: ndarray, angle: float) -> ndarray:\n    \"\"\"\n    Return the rotation matrix associated with counterclockwise rotation about\n    the given axis by the given angle.\n\n    Can be used to rotate a coordinate vector by multiplying it with the rotation\n    matrix.\n\n    Args:\n        axis: The axis of rotation.\n        angle: The angle of rotation in radians.\n\n    Returns:\n        A 3x3 rotation matrix.\n    \"\"\"\n    axis = axis / np.sqrt(np.dot(axis, axis))\n    a = np.cos(angle / 2.0)\n    b, c, d = -axis * np.sin(angle / 2.0)\n    aa, bb, cc, dd = a * a, b * b, c * c, d * d\n    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d\n    return np.array(\n        [\n            [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],\n            [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],\n            [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc],\n        ]\n    )\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.create_cone_frustum_mesh","title":"create_cone_frustum_mesh(length, radius_bottom, radius_top, bottom_dome=False, top_dome=False, resolution=100)","text":"

Generates mesh points for a cone frustum, with optional domes at either end.

This is used to render the traced morphology in 3D (and to project it to 2D) as part of plot_morph. Sections between two traced coordinates with two different radii can be represented by a cone frustum. Additionally, the ends of the frustum can be capped with hemispheres to ensure that two neighbouring frustums are connected smoothly (like ball joints).

Parameters:

Name Type Description Default length float

The length of the frustum.

required radius_bottom float

The radius of the bottom of the frustum.

required radius_top float

The radius of the top of the frustum.

required bottom_dome bool

If True, a dome is added to the bottom of the frustum. The dome is a hemisphere with radius radius_bottom.

False top_dome bool

If True, a dome is added to the top of the frustum. The dome is a hemisphere with radius radius_top.

False resolution int

defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.

100

Returns:

Type Description ndarray

An array of mesh points.

Source code in jaxley/utils/plot_utils.py
def create_cone_frustum_mesh(\n    length: float,\n    radius_bottom: float,\n    radius_top: float,\n    bottom_dome: bool = False,\n    top_dome: bool = False,\n    resolution: int = 100,\n) -> ndarray:\n    \"\"\"Generates mesh points for a cone frustum, with optional domes at either end.\n\n    This is used to render the traced morphology in 3D (and to project it to 2D)\n    as part of `plot_morph`. Sections between two traced coordinates with two\n    different radii can be represented by a cone frustum. Additionally, the ends\n    of the frustum can be capped with hemispheres to ensure that two neighbouring\n    frustums are connected smoothly (like ball joints).\n\n    Args:\n        length: The length of the frustum.\n        radius_bottom: The radius of the bottom of the frustum.\n        radius_top: The radius of the top of the frustum.\n        bottom_dome: If True, a dome is added to the bottom of the frustum.\n            The dome is a hemisphere with radius `radius_bottom`.\n        top_dome: If True, a dome is added to the top of the frustum.\n            The dome is a hemisphere with radius `radius_top`.\n        resolution: defines the resolution of the mesh.\n            If too low (typically <10), can result in errors.\n            Useful too have a simpler mesh for plotting.\n\n    Returns:\n        An array of mesh points.\n    \"\"\"\n\n    t = np.linspace(0, 2 * np.pi, resolution)\n\n    # Determine the total height including domes\n    total_height = length\n    total_height += radius_bottom if bottom_dome else 0\n    total_height += radius_top if top_dome else 0\n\n    z = np.linspace(0, total_height, resolution)\n    t_grid, z_coords = np.meshgrid(t, z)\n\n    # Initialize arrays\n    x_coords = np.zeros_like(t_grid)\n    y_coords = np.zeros_like(t_grid)\n    r_coords = np.zeros_like(t_grid)\n\n    # Bottom hemisphere\n    if bottom_dome:\n        dome_mask = z_coords < radius_bottom\n        arg = 1 - z_coords[dome_mask] / radius_bottom\n        arg[np.isclose(arg, 1, atol=1e-6, rtol=1e-6)] = 1\n        arg[np.isclose(arg, -1, atol=1e-6, rtol=1e-6)] = -1\n        phi = np.arccos(1 - z_coords[dome_mask] / radius_bottom)\n        r_coords[dome_mask] = radius_bottom * np.sin(phi)\n        z_coords[dome_mask] = z_coords[dome_mask]\n\n    # Frustum\n    frustum_start = radius_bottom if bottom_dome else 0\n    frustum_end = total_height - (radius_top if top_dome else 0)\n    frustum_mask = (z_coords >= frustum_start) & (z_coords <= frustum_end)\n    z_frustum = z_coords[frustum_mask] - frustum_start\n    r_coords[frustum_mask] = radius_bottom + (radius_top - radius_bottom) * (\n        z_frustum / length\n    )\n\n    # Top hemisphere\n    if top_dome:\n        dome_mask = z_coords > (total_height - radius_top)\n        arg = (z_coords[dome_mask] - (total_height - radius_top)) / radius_top\n        arg[np.isclose(arg, 1, atol=1e-6, rtol=1e-6)] = 1\n        arg[np.isclose(arg, -1, atol=1e-6, rtol=1e-6)] = -1\n        phi = np.arccos(arg)\n        r_coords[dome_mask] = radius_top * np.sin(phi)\n\n    x_coords = r_coords * np.cos(t_grid)\n    y_coords = r_coords * np.sin(t_grid)\n\n    return np.stack([x_coords, y_coords, z_coords])\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.create_cylinder_mesh","title":"create_cylinder_mesh(length, radius, resolution=100)","text":"

Generates mesh points for a cylinder.

This is used to render cylindrical compartments in 3D (and to project it to 2D) as part of plot_comps.

Parameters:

Name Type Description Default length float

The length of the cylinder.

required radius float

The radius of the cylinder.

required resolution int

defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.

100

Returns:

Type Description ndarray

An array of mesh points.

Source code in jaxley/utils/plot_utils.py
def create_cylinder_mesh(\n    length: float, radius: float, resolution: int = 100\n) -> ndarray:\n    \"\"\"Generates mesh points for a cylinder.\n\n    This is used to render cylindrical compartments in 3D (and to project it to 2D)\n    as part of `plot_comps`.\n\n    Args:\n        length: The length of the cylinder.\n        radius: The radius of the cylinder.\n        resolution: defines the resolution of the mesh.\n            If too low (typically <10), can result in errors.\n            Useful too have a simpler mesh for plotting.\n\n    Returns:\n        An array of mesh points.\n    \"\"\"\n    # Define cylinder\n    t = np.linspace(0, 2 * np.pi, resolution)\n    z_coords = np.linspace(-length / 2, length / 2, resolution)\n    t_grid, z_coords = np.meshgrid(t, z_coords)\n\n    x_coords = radius * np.cos(t_grid)\n    y_coords = radius * np.sin(t_grid)\n    return np.stack([x_coords, y_coords, z_coords])\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.create_sphere_mesh","title":"create_sphere_mesh(radius, resolution=100)","text":"

Generates mesh points for a sphere.

This is used to render spherical compartments in 3D (and to project it to 2D) as part of plot_comps.

Parameters:

Name Type Description Default radius float

The radius of the sphere.

required resolution int

defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.

100

Returns:

Type Description ndarray

An array of mesh points.

Source code in jaxley/utils/plot_utils.py
def create_sphere_mesh(radius: float, resolution: int = 100) -> np.ndarray:\n    \"\"\"Generates mesh points for a sphere.\n\n    This is used to render spherical compartments in 3D (and to project it to 2D)\n    as part of `plot_comps`.\n\n    Args:\n        radius: The radius of the sphere.\n        resolution: defines the resolution of the mesh.\n            If too low (typically <10), can result in errors.\n            Useful too have a simpler mesh for plotting.\n\n    Returns:\n        An array of mesh points.\n    \"\"\"\n    phi = np.linspace(0, np.pi, resolution)\n    theta = np.linspace(0, 2 * np.pi, resolution)\n\n    # Create a 2D meshgrid for phi and theta\n    phi_coords, theta_coords = np.meshgrid(phi, theta)\n\n    # Convert spherical coordinates to Cartesian coordinates\n    x_coords = radius * np.sin(phi_coords) * np.cos(theta_coords)\n    y_coords = radius * np.sin(phi_coords) * np.sin(theta_coords)\n    z_coords = radius * np.cos(phi_coords)\n\n    return np.stack([x_coords, y_coords, z_coords])\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.extract_outline","title":"extract_outline(points)","text":"

Get the outline of a 2D/3D shape.

Extracts the subset of points which form the convex hull, i.e. the outline of the input points.

Parameters:

Name Type Description Default points ndarray

An array of points / corrdinates.

required

Returns:

Type Description ndarray

An array of points which form the convex hull.

Source code in jaxley/utils/plot_utils.py
def extract_outline(points: ndarray) -> ndarray:\n    \"\"\"Get the outline of a 2D/3D shape.\n\n    Extracts the subset of points which form the convex hull, i.e. the outline of\n    the input points.\n\n    Args:\n        points: An array of points / corrdinates.\n\n    Returns:\n        An array of points which form the convex hull.\n    \"\"\"\n    hull = ConvexHull(points)\n    hull_points = points[hull.vertices]\n    return hull_points\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_comps","title":"plot_comps(module_or_view, dims=(0, 1), color='k', ax=None, true_comp_length=True, resolution=100, **kwargs)","text":"

Plot compartmentalized neural morphology.

Plots the projection of the cylindrical compartments.

Parameters:

Name Type Description Default module_or_view Union[Module, View]

The module or view to plot.

required dims Tuple[int]

The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.

(0, 1) color str

The color for all compartments

'k' ax Optional[Axes]

The matplotlib axis to plot on.

None true_comp_length bool

If True, the length of the compartment is used, i.e. the length of the traced neurite. This means for zig-zagging neurites the cylinders will be longer than the straight-line distance between the start and end point of the neurite. This can lead to overlapping and miss-aligned cylinders. Setting this False will use the straight-line distance instead for nicer plots.

True resolution int

defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.

100 kwargs

The plot kwargs for plt.fill.

{}

Returns:

Type Description Axes

Plot of the compartmentalized morphology.

Source code in jaxley/utils/plot_utils.py
def plot_comps(\n    module_or_view: Union[\"jx.Module\", \"jx.View\"],\n    dims: Tuple[int] = (0, 1),\n    color: str = \"k\",\n    ax: Optional[Axes] = None,\n    true_comp_length: bool = True,\n    resolution: int = 100,\n    **kwargs,\n) -> Axes:\n    \"\"\"Plot compartmentalized neural morphology.\n\n    Plots the projection of the cylindrical compartments.\n\n    Args:\n        module_or_view: The module or view to plot.\n        dims: The dimensions to plot / to project the cylinder onto,\n            i.e. [0,1] xy-plane or [0,1,2] for 3D.\n        color: The color for all compartments\n        ax: The matplotlib axis to plot on.\n        true_comp_length: If True, the length of the compartment is used, i.e. the\n            length of the traced neurite. This means for zig-zagging neurites the\n            cylinders will be longer than the straight-line distance between the\n            start and end point of the neurite. This can lead to overlapping and\n            miss-aligned cylinders. Setting this False will use the straight-line\n            distance instead for nicer plots.\n        resolution: defines the resolution of the mesh.\n            If too low (typically <10), can result in errors.\n            Useful too have a simpler mesh for plotting.\n        kwargs: The plot kwargs for plt.fill.\n\n    Returns:\n        Plot of the compartmentalized morphology.\n    \"\"\"\n    if ax is None:\n        fig = plt.figure(figsize=(3, 3))\n        ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n    assert not np.any(\n        np.isnan(module_or_view.xyzr[0][:, :3])\n    ), \"missing xyz coordinates.\"\n    if \"x\" not in module_or_view.nodes.columns:\n        module_or_view.compute_compartment_centers()\n\n    for idx, xyzr in zip(module_or_view._branches_in_view, module_or_view.xyzr):\n        locs = xyzr[:, :3]\n        if locs.shape[0] == 1:  # assume spherical comp\n            radius = xyzr[:, -1]\n            center = xyzr[0, :3]\n            if len(dims) == 3:\n                xyz = create_sphere_mesh(radius, resolution)\n                ax = plot_mesh(\n                    xyz,\n                    np.array([0, 0, 1]),\n                    center,\n                    np.array(dims),\n                    ax,\n                    color=color,\n                    **kwargs,\n                )\n            else:\n                ax.add_artist(plt.Circle(locs[0, dims], radius, color=color))\n        else:\n            lens = np.sqrt(np.nansum(np.diff(locs, axis=0) ** 2, axis=1))\n            lens = np.cumsum([0] + lens.tolist())\n            comp_ends = v_interp(\n                np.linspace(0, lens[-1], module_or_view.ncomp + 1), lens, locs\n            ).T\n            axes = np.diff(comp_ends, axis=0)\n            cylinder_lens = np.sqrt(np.sum(axes**2, axis=1))\n\n            branch_df = module_or_view.nodes[\n                module_or_view.nodes[\"global_branch_index\"] == idx\n            ]\n            for l, axis, (i, comp) in zip(cylinder_lens, axes, branch_df.iterrows()):\n                center = comp[[\"x\", \"y\", \"z\"]]\n                radius = comp[\"radius\"]\n                length = comp[\"length\"] if true_comp_length else l\n                xyz = create_cylinder_mesh(length, radius, resolution)\n                ax = plot_mesh(\n                    xyz,\n                    axis,\n                    center,\n                    np.array(dims),\n                    ax,\n                    color=color,\n                    **kwargs,\n                )\n    return ax\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_graph","title":"plot_graph(xyzr, dims=(0, 1), color='k', ax=None, type='line', **kwargs)","text":"

Plot morphology.

Parameters:

Name Type Description Default xyzr ndarray

The coordinates of the morphology.

required dims Tuple[int]

Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two or three of them.

(0, 1) color str

The color for all branches.

'k' ax Optional[Axes]

The matplotlib axis to plot on.

None type str

Either line or scatter.

'line' kwargs

The plot kwargs for plt.plot or plt.scatter.

{} Source code in jaxley/utils/plot_utils.py
def plot_graph(\n    xyzr: ndarray,\n    dims: Tuple[int] = (0, 1),\n    color: str = \"k\",\n    ax: Optional[Axes] = None,\n    type: str = \"line\",\n    **kwargs,\n) -> Axes:\n    \"\"\"Plot morphology.\n\n    Args:\n        xyzr: The coordinates of the morphology.\n        dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n            two or three of them.\n        color: The color for all branches.\n        ax: The matplotlib axis to plot on.\n        type: Either `line` or `scatter`.\n        kwargs: The plot kwargs for plt.plot or plt.scatter.\n    \"\"\"\n\n    if ax is None:\n        fig = plt.figure(figsize=(3, 3))\n        ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n    for coords_of_branch in xyzr:\n        points = coords_of_branch[:, dims].T\n\n        if \"line\" in type.lower():\n            _ = ax.plot(*points, color=color, **kwargs)\n        elif \"scatter\" in type.lower():\n            _ = ax.scatter(*points, color=color, **kwargs)\n        else:\n            raise NotImplementedError\n\n    return ax\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_mesh","title":"plot_mesh(mesh_points, orientation, center, dims, ax=None, **kwargs)","text":"

Plot the 2D projection of a volume mesh on a cardinal plane.

Project the projection of a cylinder that is oriented in 3D space. - Create cylinder mesh - rotate cylinder mesh to orient it lengthwise along a given orientation vector. - move its center - project onto plane - compute outline of projected mesh. - fill area inside the outline

Parameters:

Name Type Description Default mesh_points ndarray

coordinates of the xyz mesh that define the volume

required orientation ndarray

orientation vector. The cylinder will be oriented along this vector.

required center ndarray

The x,y,z coordinates of the center of the cylinder.

required dims Tuple[int]

The dimensions to plot / to project the cylinder onto,

required ax Axes

The matplotlib axis to plot on.

None

Returns:

Type Description Axes

Plot of the cylinder projection.

Source code in jaxley/utils/plot_utils.py
def plot_mesh(\n    mesh_points: ndarray,\n    orientation: ndarray,\n    center: ndarray,\n    dims: Tuple[int],\n    ax: Axes = None,\n    **kwargs,\n) -> Axes:\n    \"\"\"Plot the 2D projection of a volume mesh on a cardinal plane.\n\n    Project the projection of a cylinder that is oriented in 3D space.\n    - Create cylinder mesh\n    - rotate cylinder mesh to orient it lengthwise along a given orientation vector.\n    - move its center\n    - project onto plane\n    - compute outline of projected mesh.\n    - fill area inside the outline\n\n    Args:\n        mesh_points: coordinates of the xyz mesh that define the volume\n        orientation: orientation vector. The cylinder will be oriented along this vector.\n        center: The x,y,z coordinates of the center of the cylinder.\n        dims: The dimensions to plot / to project the cylinder onto,\n        i.e. [0,1] xy-plane or [0,1,2] for 3D.\n        ax: The matplotlib axis to plot on.\n\n    Returns:\n        Plot of the cylinder projection.\n    \"\"\"\n    if ax is None:\n        fig = plt.figure(figsize=(3, 3))\n        ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n    # Normalize axis vector\n    orientation = np.array(orientation)\n    orientation = orientation / np.linalg.norm(orientation)\n\n    # Create a rotation matrix to align the cylinder with the given axis\n    z_axis = np.array([0, 0, 1])\n    rotation_axis = np.cross(z_axis, orientation)\n    rotation_angle = np.arccos(np.dot(z_axis, orientation))\n\n    if np.allclose(rotation_axis, 0):\n        rotation_matrix = np.eye(3)\n    else:\n        rotation_matrix = compute_rotation_matrix(rotation_axis, rotation_angle)\n\n    # Rotate mesh\n    x_mesh, y_mesh, z_mesh = mesh_points\n    rotated_mesh_points = np.dot(\n        rotation_matrix,\n        np.array([x_mesh.flatten(), y_mesh.flatten(), z_mesh.flatten()]),\n    )\n    rotated_mesh_points = rotated_mesh_points.reshape(3, -1)\n\n    # project onto plane and move\n    rotated_mesh_points = rotated_mesh_points[dims]\n    rotated_mesh_points += np.array(center)[dims, np.newaxis]\n\n    if len(dims) < 3:\n        # get outline of cylinder mesh\n        mesh_outline = extract_outline(rotated_mesh_points.T).T\n        ax.fill(*mesh_outline.reshape(mesh_outline.shape[0], -1), **kwargs)\n    else:\n        # plot 3d mesh\n        ax.plot_surface(*rotated_mesh_points.reshape(*mesh_points.shape), **kwargs)\n    return ax\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_morph","title":"plot_morph(module_or_view, dims=(0, 1), color='k', ax=None, resolution=100, **kwargs)","text":"

Plot the detailed morphology.

Plots the traced morphology it was traced. That means at every point that was traced a disc of radius r is plotted. The outline of the discs are then connected to form the morphology. This means every trace segement can be represented by a cone frustum. To prevent breaks in the morphology, each segement is connected with a ball joint.

Parameters:

Name Type Description Default module_or_view Union[Module, View]

The module or view to plot.

required dims Tuple[int]

The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.

(0, 1) color str

The color for all branches

'k' ax Optional[Axes]

The matplotlib axis to plot on.

None kwargs

The plot kwargs for plt.fill.

{} resolution int

defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.

100

Returns:

Type Description Axes

Plot of the detailed morphology.

Source code in jaxley/utils/plot_utils.py
def plot_morph(\n    module_or_view: Union[\"jx.Module\", \"jx.View\"],\n    dims: Tuple[int] = (0, 1),\n    color: str = \"k\",\n    ax: Optional[Axes] = None,\n    resolution: int = 100,\n    **kwargs,\n) -> Axes:\n    \"\"\"Plot the detailed morphology.\n\n    Plots the traced morphology it was traced. That means at every point that was\n    traced a disc of radius `r` is plotted. The outline of the discs are then\n    connected to form the morphology. This means every trace segement can be\n    represented by a cone frustum. To prevent breaks in the morphology, each\n    segement is connected with a ball joint.\n\n    Args:\n        module_or_view: The module or view to plot.\n        dims: The dimensions to plot / to project the cylinder onto,\n            i.e. [0,1] xy-plane or [0,1,2] for 3D.\n        color: The color for all branches\n        ax: The matplotlib axis to plot on.\n        kwargs: The plot kwargs for plt.fill.\n\n        resolution: defines the resolution of the mesh.\n            If too low (typically <10), can result in errors.\n            Useful too have a simpler mesh for plotting.\n\n    Returns:\n        Plot of the detailed morphology.\"\"\"\n    if ax is None:\n        fig = plt.figure(figsize=(3, 3))\n        ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n    if len(dims) == 3:\n        warn(\n            \"rendering large morphologies in 3D can take a while. Consider projecting to 2D instead.\"\n        )\n\n    assert not np.any(\n        np.isnan(module_or_view.xyzr[0][:, :3])\n    ), \"missing xyz coordinates.\"\n\n    for xyzr in module_or_view.xyzr:\n        if len(xyzr) > 1:\n            for xyzr1, xyzr2 in zip(xyzr[1:, :], xyzr[:-1, :]):\n                dxyz = xyzr2[:3] - xyzr1[:3]\n                length = np.sqrt(np.sum(dxyz**2))\n                points = create_cone_frustum_mesh(\n                    length,\n                    xyzr1[-1],\n                    xyzr2[-1],\n                    bottom_dome=True,\n                    top_dome=True,\n                    resolution=resolution,\n                )\n                plot_mesh(\n                    points,\n                    dxyz,\n                    xyzr1[:3],\n                    np.array(dims),\n                    color=color,\n                    ax=ax,\n                    **kwargs,\n                )\n        else:\n            points = create_cone_frustum_mesh(\n                0,\n                xyzr[:, -1],\n                xyzr[:, -1],\n                bottom_dome=True,\n                top_dome=True,\n                resolution=resolution,\n            )\n            plot_mesh(\n                points,\n                np.ones(3),\n                xyzr[0, :3],\n                dims=np.array(dims),\n                color=color,\n                ax=ax,\n                **kwargs,\n            )\n\n    return ax\n
"},{"location":"reference/utils/#jaxley.utils.jax_utils.nested_checkpoint_scan","title":"nested_checkpoint_scan(f, init, xs, length=None, *, nested_lengths, scan_fn=jax.lax.scan, checkpoint_fn=jax.checkpoint)","text":"

A version of lax.scan that supports recursive gradient checkpointing.

Code taken from: https://github.com/google/jax/issues/2139

The interface of nested_checkpoint_scan exactly matches lax.scan, except for the required nested_lengths argument.

The key feature of nested_checkpoint_scan is that gradient calculations require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested scans, which it achieves by re-evaluating the forward pass len(nested_lengths) - 1 times.

nested_checkpoint_scan reduces to lax.scan when nested_lengths has a single element.

Parameters:

Name Type Description Default f Callable[[Carry, Dict[str, ndarray]], Tuple[Carry, Output]]

function to scan over.

required init Carry

initial value.

required xs Dict[str, ndarray]

scanned over values.

required length Optional[int]

leading length of all dimensions

None nested_lengths Sequence[int]

required list of lengths to scan over for each level of checkpointing. The product of nested_lengths must match length (if provided) and the size of the leading axis for all arrays in xs.

required scan_fn

function matching the API of lax.scan

scan checkpoint_fn Callable[[Func], Func]

function matching the API of jax.checkpoint.

checkpoint Source code in jaxley/utils/jax_utils.py
def nested_checkpoint_scan(\n    f: Callable[[Carry, Dict[str, jnp.ndarray]], Tuple[Carry, Output]],\n    init: Carry,\n    xs: Dict[str, jnp.ndarray],\n    length: Optional[int] = None,\n    *,\n    nested_lengths: Sequence[int],\n    scan_fn=jax.lax.scan,\n    checkpoint_fn: Callable[[Func], Func] = jax.checkpoint,\n):\n    \"\"\"A version of lax.scan that supports recursive gradient checkpointing.\n\n    Code taken from: https://github.com/google/jax/issues/2139\n\n    The interface of `nested_checkpoint_scan` exactly matches lax.scan, except for\n    the required `nested_lengths` argument.\n\n    The key feature of `nested_checkpoint_scan` is that gradient calculations\n    require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested\n    scans, which it achieves by re-evaluating the forward pass\n    `len(nested_lengths) - 1` times.\n\n    `nested_checkpoint_scan` reduces to `lax.scan` when `nested_lengths` has a\n    single element.\n\n    Args:\n        f: function to scan over.\n        init: initial value.\n        xs: scanned over values.\n        length: leading length of all dimensions\n        nested_lengths: required list of lengths to scan over for each level of\n            checkpointing. The product of nested_lengths must match length (if\n            provided) and the size of the leading axis for all arrays in ``xs``.\n        scan_fn: function matching the API of lax.scan\n        checkpoint_fn: function matching the API of jax.checkpoint.\n    \"\"\"\n    if length is not None and length != math.prod(nested_lengths):\n        raise ValueError(f\"inconsistent {length=} and {nested_lengths=}\")\n\n    def nested_reshape(x):\n        x = jnp.asarray(x)\n        new_shape = tuple(nested_lengths) + x.shape[1:]\n        return x.reshape(new_shape)\n\n    sub_xs = jax.tree_util.tree_map(nested_reshape, xs)\n    return _inner_nested_scan(f, init, sub_xs, nested_lengths, scan_fn, checkpoint_fn)\n
"},{"location":"reference/utils/#jaxley.utils.syn_utils.gather_synapes","title":"gather_synapes(number_of_compartments, post_syn_comp_inds, current_each_synapse_voltage_term, current_each_synapse_constant_term)","text":"

Compute current at the post synapse.

All this does it that it sums the synaptic currents that come into a particular compartment. It returns an array of as many elements as there are compartments.

Source code in jaxley/utils/syn_utils.py
def gather_synapes(\n    number_of_compartments: jnp.ndarray,\n    post_syn_comp_inds: np.ndarray,\n    current_each_synapse_voltage_term: jnp.ndarray,\n    current_each_synapse_constant_term: jnp.ndarray,\n) -> Tuple[jnp.ndarray, jnp.ndarray]:\n    \"\"\"Compute current at the post synapse.\n\n    All this does it that it sums the synaptic currents that come into a particular\n    compartment. It returns an array of as many elements as there are compartments.\n    \"\"\"\n    incoming_currents_voltages = jnp.zeros((number_of_compartments,))\n    incoming_currents_contant = jnp.zeros((number_of_compartments,))\n\n    dnums = ScatterDimensionNumbers(\n        update_window_dims=(),\n        inserted_window_dims=(0,),\n        scatter_dims_to_operand_dims=(0,),\n    )\n    incoming_currents_voltages = scatter_add(\n        incoming_currents_voltages,\n        post_syn_comp_inds[:, None],\n        current_each_synapse_voltage_term,\n        dnums,\n    )\n    incoming_currents_contant = scatter_add(\n        incoming_currents_contant,\n        post_syn_comp_inds[:, None],\n        current_each_synapse_constant_term,\n        dnums,\n    )\n    return incoming_currents_voltages, incoming_currents_contant\n
"},{"location":"tutorial/00_jaxley_api/","title":"Key concepts in Jaxley","text":"

In this tutorial, we will introduce you to the basic concepts of Jaxley. You will learn about:

  • Modules (e.g., Cell, Network,\u2026)
    • nodes
    • edges
  • Views
    • Groups
  • Channels
  • Synapses

Here is a code snippet which you will learn to understand in this tutorial:

import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import connect\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n\n# Assembling different Modules into a Network\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=1)\ncell = jx.Cell(branch, parents=[-1, 0, 0])\nnet = jx.Network([cell]*3)\n\n# Navigating and inspecting the Modules using Views\ncell0 = net.cell(0)\ncell0.nodes\n\n# How to group together parts of Modules\nnet.cell(1).add_to_group(\"cell1\")\n\n# inserting channels in the membrane\nwith net.cell(0) as cell0:\n    cell0.insert(Na())\n    cell0.insert(K())\n\n# connecting two cells using a Synapse\npre_comp = cell0.branch(1).comp(0)\npost_comp = net.cell1.branch(0).comp(0)\n\nconnect(pre_comp, post_comp)\n

First, we import the relevant libraries:

from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import connect\nimport matplotlib.pyplot as plt\nimport numpy as np\n
"},{"location":"tutorial/00_jaxley_api/#modules","title":"Modules","text":"

In Jaxley, we heavily rely on the concept of Modules to build biophyiscal models of neural systems at various scales. Jaxley implements four types of Modules: - Compartment - Branch - Cell - Network

Modules can be connected together to build increasingly detailed and complex models. Compartment -> Branch -> Cell -> Network.

Compartments are the atoms of biophysical models in Jaxley. All mechanisms and synaptic connections live on the level of Compartments and can already be simulated using jx.integrate on their own. Everything you do in Jaxley starts with a Compartment.

comp = jx.Compartment() # single compartment model.\n

Mutliple Compartments can be connected together to form longer, linear cables, which we call Branches and are equivalent to sections in NEURON.

ncomp = 4\nbranch = jx.Branch([comp] * ncomp)\n

In order to construct cell morphologies in Jaxley, multiple Branches can to be connected together as a Cell:

# -1 indicates that the first branch has no parent branch.\n# The other two branches both have the 0-eth branch as their parent.\nparents = [-1, 0, 0]\ncell = jx.Cell([branch] * len(parents), parents)\n

Finally, several Cells can be grouped together to form a Network, which can than be connected together using Synpases.

ncells = 2\nnet = jx.Network([cell]*ncells)\n\nnet.shape # shows you the num_cells, num_branches, num_comps\n
(2, 6, 24)\n

Every module tracks information about its current state and parameters in two Dataframes called nodes and edges. nodes contains all the information that we associate with compartments in the model (each row corresponds to one compartment) and edges tracks all the information relevant to synapses.

This means that you can easily keep track of the current state of your Module and how it changes at all times.

net.nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 0 0 0 0 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 0 0 1 0 2 0 0 2 10.0 1.0 5000.0 1.0 -70.0 0 0 2 0 3 0 0 3 10.0 1.0 5000.0 1.0 -70.0 0 0 3 0 4 0 1 0 10.0 1.0 5000.0 1.0 -70.0 0 1 4 0 5 0 1 1 10.0 1.0 5000.0 1.0 -70.0 0 1 5 0 6 0 1 2 10.0 1.0 5000.0 1.0 -70.0 0 1 6 0 7 0 1 3 10.0 1.0 5000.0 1.0 -70.0 0 1 7 0 8 0 2 0 10.0 1.0 5000.0 1.0 -70.0 0 2 8 0 9 0 2 1 10.0 1.0 5000.0 1.0 -70.0 0 2 9 0 10 0 2 2 10.0 1.0 5000.0 1.0 -70.0 0 2 10 0 11 0 2 3 10.0 1.0 5000.0 1.0 -70.0 0 2 11 0 12 1 0 0 10.0 1.0 5000.0 1.0 -70.0 1 3 12 0 13 1 0 1 10.0 1.0 5000.0 1.0 -70.0 1 3 13 0 14 1 0 2 10.0 1.0 5000.0 1.0 -70.0 1 3 14 0 15 1 0 3 10.0 1.0 5000.0 1.0 -70.0 1 3 15 0 16 1 1 0 10.0 1.0 5000.0 1.0 -70.0 1 4 16 0 17 1 1 1 10.0 1.0 5000.0 1.0 -70.0 1 4 17 0 18 1 1 2 10.0 1.0 5000.0 1.0 -70.0 1 4 18 0 19 1 1 3 10.0 1.0 5000.0 1.0 -70.0 1 4 19 0 20 1 2 0 10.0 1.0 5000.0 1.0 -70.0 1 5 20 0 21 1 2 1 10.0 1.0 5000.0 1.0 -70.0 1 5 21 0 22 1 2 2 10.0 1.0 5000.0 1.0 -70.0 1 5 22 0 23 1 2 3 10.0 1.0 5000.0 1.0 -70.0 1 5 23 0
net.edges.head() # this is currently empty since we have not made any connections yet\n
global_edge_index pre_global_comp_index post_global_comp_index pre_locs post_locs type type_ind"},{"location":"tutorial/00_jaxley_api/#views","title":"Views","text":"

Since these Modules can become very complex, Jaxley utilizes so called Views to make working with Modules easy and intuitive.

The simplest way to navigate Modules is by navigating them via the hierachy that we introduced above. A View is what you get when you index into the module. For example, for a Network:

net.cell(0)\n
View with 0 different channels. Use `.nodes` for details.\n

Views behave very similarly to Modules, i.e. the cell(0) (the 0th cell of the network) behaves like the cell we instantiated earlier. As such, cell(0) also has a nodes attribute, which keeps track of it\u2019s part of the network:

net.cell(0).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 0 0 0 0 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 0 0 1 0 2 0 0 2 10.0 1.0 5000.0 1.0 -70.0 0 0 2 0 3 0 0 3 10.0 1.0 5000.0 1.0 -70.0 0 0 3 0 4 0 1 0 10.0 1.0 5000.0 1.0 -70.0 0 1 4 0 5 0 1 1 10.0 1.0 5000.0 1.0 -70.0 0 1 5 0 6 0 1 2 10.0 1.0 5000.0 1.0 -70.0 0 1 6 0 7 0 1 3 10.0 1.0 5000.0 1.0 -70.0 0 1 7 0 8 0 2 0 10.0 1.0 5000.0 1.0 -70.0 0 2 8 0 9 0 2 1 10.0 1.0 5000.0 1.0 -70.0 0 2 9 0 10 0 2 2 10.0 1.0 5000.0 1.0 -70.0 0 2 10 0 11 0 2 3 10.0 1.0 5000.0 1.0 -70.0 0 2 11 0

Let\u2019s use Views to visualize only parts of the Network. Before we do that, we create x, y, and z coordinates for the Network:

# Compute xyz coordinates of the cells.\nnet.compute_xyz()\n\n# Move cells (since they are placed on top of each other by default).\nnet.cell(0).move(y=30)\n

We can now visualize the entire net (i.e., the entire Module) with the .vis() method\u2026

# We can use the vis function to visualize Modules.\nfig, ax = plt.subplots(1, 1, figsize=(3,3))\nnet.vis(ax=ax)\n
<Axes: >\n

\u2026but we can also create a View to visualize only parts of the net:

# ... and Views\nfig, ax = plt.subplots(1,1, figsize=(3,3))\nnet.cell(0).vis(ax=ax, color=\"blue\") # View of the 0th cell of the network\nnet.cell(1).vis(ax=ax, color=\"red\") # View of the 1st cell of the network\n\nnet.cell(0).branch(0).vis(ax=ax, color=\"green\") # View of the 1st branch of the 0th cell of the network\nnet.cell(1).branch(1).comp(1).vis(ax=ax, color=\"black\", type=\"line\") # View of the 0th comp of the 1st branch of the 0th cell of the network\n
<Axes: >\n

"},{"location":"tutorial/00_jaxley_api/#how-to-create-views","title":"How to create Views","text":"

Above, we used net.cell(0) to generate a View of the 0-eth cell. Jaxley supports many ways of performing such indexing:

# several types of indices are supported (lists, ranges, ...)\nnet.cell([0,1]).branch(\"all\").comp(0)  # View of all 0th comps of all branches of cell 0 and 1\n\nbranch.loc(0.1)  # Equivalent to `NEURON`s `loc`. Assumes branches are continous from 0-1.\n\nnet[0,0,0]  # Modules/Views can also be lazily indexed\n\ncell0 = net.cell(0)  # Views can be assigned to variables and only track the parts of the Module they belong to\ncell0.branch(1).comp(0)  # Views can be continuely indexed\n
View with 0 different channels. Use `.nodes` for details.\n
cell0.nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 0 0 0 0 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 0 0 1 0 2 0 0 2 10.0 1.0 5000.0 1.0 -70.0 0 0 2 0 3 0 0 3 10.0 1.0 5000.0 1.0 -70.0 0 0 3 0 4 0 1 0 10.0 1.0 5000.0 1.0 -70.0 0 1 4 0 5 0 1 1 10.0 1.0 5000.0 1.0 -70.0 0 1 5 0 6 0 1 2 10.0 1.0 5000.0 1.0 -70.0 0 1 6 0 7 0 1 3 10.0 1.0 5000.0 1.0 -70.0 0 1 7 0 8 0 2 0 10.0 1.0 5000.0 1.0 -70.0 0 2 8 0 9 0 2 1 10.0 1.0 5000.0 1.0 -70.0 0 2 9 0 10 0 2 2 10.0 1.0 5000.0 1.0 -70.0 0 2 10 0 11 0 2 3 10.0 1.0 5000.0 1.0 -70.0 0 2 11 0
net.shape\n
(2, 6, 24)\n

Note: In case you need even more flexibility in how you select parts of a Module, Jaxley provides a select method, to give full control over the exact parts of the nodes and edges that are part of a View. On examples of how this can be used, see the tutorial on advanced indexing.

You can also iterate over networks, cells, and branches:

# We set the radiuses to random values...\nradiuses = np.random.rand((24))\nnet.set(\"radius\", radiuses)\n\n# ...and then we set the length to 100.0 um if the radius is >0.5.\nfor cell in net:\n    for branch in cell:\n        for comp in branch:\n            if comp.nodes.iloc[0][\"radius\"] > 0.5:\n                comp.set(\"length\", 100.0)\n\n# Show the first five compartments:\nnet.nodes[[\"radius\", \"length\"]][:5]\n
radius length 0 0.537066 100.0 1 0.050138 10.0 2 0.913129 100.0 3 0.874596 100.0 4 0.048903 10.0

Finally, you can also use Views in a context manager:

with net.cell(0).branch(0) as branch0:\n    branch0.set(\"radius\", 2.0)\n    branch0.set(\"length\", 2.5)\n\n# Show the first five compartments.\nnet.nodes[[\"radius\", \"length\"]][:5]\n
radius length 0 2.000000 2.5 1 2.000000 2.5 2 2.000000 2.5 3 2.000000 2.5 4 0.048903 10.0"},{"location":"tutorial/00_jaxley_api/#channels","title":"Channels","text":"

The Modules that we have created above will not do anything interesting, since by default Jaxley initializes them without any mechanisms in the membrane. To change this, we have to insert channels into the membrane. For this purpose Jaxley implements Channels that can be inserted into any compartment using the insert method of a Module or a View:

# insert a Leak channel into all compartments in the Module.\nnet.insert(Leak())\nnet.nodes.head() # Channel parameters are now also added to `nodes`.\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param Leak Leak_gLeak Leak_eLeak 0 0 0 0 2.5 2.000000 5000.0 1.0 -70.0 0 0 0 0 True 0.0001 -70.0 1 0 0 1 2.5 2.000000 5000.0 1.0 -70.0 0 0 1 0 True 0.0001 -70.0 2 0 0 2 2.5 2.000000 5000.0 1.0 -70.0 0 0 2 0 True 0.0001 -70.0 3 0 0 3 2.5 2.000000 5000.0 1.0 -70.0 0 0 3 0 True 0.0001 -70.0 4 0 1 0 10.0 0.048903 5000.0 1.0 -70.0 0 1 4 0 True 0.0001 -70.0

This is also were Views come in handy, as it allows to easily target the insertion of channels to specific compartments.

# inserting several channels into parts of the network\nwith net.cell(0) as cell0:\n    cell0.insert(Na())\n    cell0.insert(K())\n\n# # The above is equivalent to:\n# net.cell(0).insert(Na())\n# net.cell(0).insert(K())\n\n# K and Na channels were only insert into cell 0\nnet.cell(\"all\").branch(0).comp(0).nodes[[\"global_cell_index\", \"Na\", \"K\", \"Leak\"]]\n
global_cell_index Na K Leak 0 0 True True True 12 1 False False True"},{"location":"tutorial/00_jaxley_api/#synapses","title":"Synapses","text":"

To connect different cells together, Jaxley implements a connect method, that can be used to couple 2 compartments together using a Synapse. Synapses in Jaxley work only on the compartment level, that means to be able to connect two cells, you need to specify the exact compartments on a given cell to make the connections between. Below is an example of this:

# connecting two cells using a Synapse\npre_comp = cell0.branch(1).comp(0)\npost_comp = net.cell(1).branch(0).comp(0)\n\nconnect(pre_comp, post_comp, IonotropicSynapse())\n\nnet.edges\n
global_edge_index pre_global_comp_index post_global_comp_index type type_ind pre_locs post_locs IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s controlled_by_param 0 0 4 12 IonotropicSynapse 0 0.125 0.125 0.0001 0.0 0.025 0.2 0

As you can see above, now the edges dataframe is also updated with the information of the newly added synapse.

Congrats! You should now have an intuitive understand of how to use Jaxley\u2019s API to construct, navigate and manipulate neuron models.

"},{"location":"tutorial/01_morph_neurons/","title":"Basics of Jaxley","text":"

In this tutorial, you will learn how to:

  • build your first morphologically detailed cell or read it from SWC
  • stimulate the cell
  • record from the cell
  • visualize cells
  • run your first simulation

Here is a code snippet which you will learn to understand in this tutorial:

import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nimport matplotlib.pyplot as plt\n\n\n# Build the cell.\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])\n\n# Insert channels.\ncell.insert(Leak())\ncell.branch(0).insert(Na())\ncell.branch(0).insert(K())\n\n# Change parameters.\ncell.set(\"axial_resistivity\", 200.0)\n\n# Visualize the morphology.\ncell.compute_xyz()\nfig, ax = plt.subplots(1, 1, figsize=(4, 4))\ncell.vis(ax=ax)\n\n# Stimulate.\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=10.0)\ncell.branch(0).loc(0.0).stimulate(current)\n\n# Record.\ncell.branch(0).loc(0.0).record(\"v\")\n\n# Simulate and plot.\nv = jx.integrate(cell, delta_t=0.025)\nplt.plot(v.T)\n

First, we import the relevant libraries:

from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect\n

We will now build our first cell in Jaxley. You have two options to do this: you can either build a cell bottom-up by defining the morphology yourselve, or you can load cells from SWC files.

"},{"location":"tutorial/01_morph_neurons/#define-the-cell-from-scratch","title":"Define the cell from scratch","text":"

To define a cell from scratch you first have to define a single compartment and then assemble those compartments into a branch:

comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\n

Next, we can assemble branches into a cell. To do so, we have to define for each branch what its parent branch is. A -1 entry means that this branch does not have a parent.

parents = jnp.asarray([-1, 0, 0, 1, 1])\ncell = jx.Cell(branch, parents=parents)\n

To learn more about Compartments, Branches, and Cells, see this tutorial.

"},{"location":"tutorial/01_morph_neurons/#read-the-cell-from-an-swc-file","title":"Read the cell from an SWC file","text":"

Alternatively, you could also load cells from SWC with

cell = jx.read_swc(fname, ncomp=4)

Details on handling SWC files can be found in this tutorial.

"},{"location":"tutorial/01_morph_neurons/#visualize-the-cells","title":"Visualize the cells","text":"

Cells can be visualized as follows:

cell.compute_xyz()  # Only needed for visualization.\n\nfig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax, color=\"k\")\n

"},{"location":"tutorial/01_morph_neurons/#insert-mechanisms","title":"Insert mechanisms","text":"

Currently, the cell does not contain any kind of ion channel (not even a leak). We can fix this by inserting a leak channel into the entire cell, and by inserting sodium and potassium into the zero-eth branch.

cell.insert(Leak())\ncell.branch(0).insert(Na())\ncell.branch(0).insert(K())\n

Once the cell is created, we can inspect its .nodes attribute which lists all properties of the cell:

cell.nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index ... Na Na_gNa eNa vt Na_m Na_h K K_gK eK K_n 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 0 0 ... True 0.05 50.0 -60.0 0.2 0.2 True 0.005 -90.0 0.2 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 0 0 ... True 0.05 50.0 -60.0 0.2 0.2 True 0.005 -90.0 0.2 2 0 1 0 10.0 1.0 5000.0 1.0 -70.0 0 1 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 3 0 1 1 10.0 1.0 5000.0 1.0 -70.0 0 1 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 4 0 2 0 10.0 1.0 5000.0 1.0 -70.0 0 2 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 5 0 2 1 10.0 1.0 5000.0 1.0 -70.0 0 2 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 6 0 3 0 10.0 1.0 5000.0 1.0 -70.0 0 3 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 7 0 3 1 10.0 1.0 5000.0 1.0 -70.0 0 3 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 8 0 4 0 10.0 1.0 5000.0 1.0 -70.0 0 4 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 9 0 4 1 10.0 1.0 5000.0 1.0 -70.0 0 4 ... False NaN NaN NaN NaN NaN False NaN NaN NaN

10 rows \u00d7 25 columns

Note that Jaxley uses the same units as the NEURON simulator, which are listed here.

You can also inspect just parts of the cell, for example its 1st branch:

cell.branch(1).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v Leak Leak_gLeak ... Na_m Na_h K K_gK eK K_n global_cell_index global_branch_index global_comp_index controlled_by_param 2 0 0 0 10.0 1.0 5000.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 2 1 3 0 0 1 10.0 1.0 5000.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 3 1

2 rows \u00d7 25 columns

The easiest way to know which branch is the 1st branch (or, e.g., the zero-eth compartment of the 1st branch) is to plot it in a different color:

fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax, color=\"k\")\n_ = cell.branch(1).vis(ax=ax, color=\"r\")\n_ = cell.branch(1).comp(1).vis(ax=ax, color=\"b\")\n

More background and features on indexing as cell.branch(0) is in this tutorial.

"},{"location":"tutorial/01_morph_neurons/#change-parameters-of-the-cell","title":"Change parameters of the cell","text":"

You can change properties of the cell with the .set() method:

cell.branch(1).set(\"axial_resistivity\", 200.0)\n

And we can again inspect the .nodes to make sure that the axial resistivity indeed changed:

cell.branch(1).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v Leak Leak_gLeak ... Na_m Na_h K K_gK eK K_n global_cell_index global_branch_index global_comp_index controlled_by_param 2 0 0 0 10.0 1.0 200.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 2 1 3 0 0 1 10.0 1.0 200.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 3 1

2 rows \u00d7 25 columns

In a similar way, you can modify channel properties or initial states (units are again here):

cell.branch(0).set(\"K_gK\", 0.01)  # modify potassium conductance.\ncell.set(\"v\", -65.0)  # modify initial voltage.\n
"},{"location":"tutorial/01_morph_neurons/#stimulate-the-cell","title":"Stimulate the cell","text":"

We next stimulate one of the compartments with a step current. For this, we first define the step current (units are again here):

dt = 0.025\nt_max = 10.0\ntime_vec = np.arange(0, t_max+dt, dt)\ncurrent = jx.step_current(i_delay=1.0, i_dur=2.0, i_amp=0.08, delta_t=dt, t_max=t_max)\n\nfig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = plt.plot(time_vec, current)\n

We then stimulate one of the compartments of the cell with this step current:

cell.delete_stimuli()\ncell.branch(0).loc(0.0).stimulate(current)\n
Added 1 external_states. See `.externals` for details.\n
"},{"location":"tutorial/01_morph_neurons/#define-recordings","title":"Define recordings","text":"

Next, you have to define where to record the voltage. In this case, we will record the voltage at two locations:

cell.delete_recordings()\ncell.branch(0).loc(0.0).record(\"v\")\ncell.branch(3).loc(1.0).record(\"v\")\n
Added 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n

We can again visualize these locations to understand where we inserted recordings:

fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax)\n_ = cell.branch(0).loc(0.0).vis(ax=ax, color=\"b\")\n_ = cell.branch(3).loc(1.0).vis(ax=ax, color=\"g\")\n

"},{"location":"tutorial/01_morph_neurons/#simulate-the-cell-response","title":"Simulate the cell response","text":"

Having set up the cell, inserted stimuli and recordings, we are now ready to run a simulation with jx.integrate:

voltages = jx.integrate(cell, delta_t=dt)\nprint(\"voltages.shape\", voltages.shape)\n
voltages.shape (2, 402)\n

The jx.integrate function returns an array of shape (num_recordings, num_timepoints). In our case, we inserted 2 recordings and we simulated for 10ms at a 0.025 time step, which leads to 402 time steps.

We can now visualize the voltage response:

fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(voltages[0], c=\"b\")\n_ = ax.plot(voltages[1], c=\"orange\")\n

At the location of the first recording (in blue) the cell spiked, whereas at the second recording, it did not. This makes sense because we only inserted sodium and potassium channels into the first branch, but not in the entire cell.

Congrats! You have just run your first morphologically detailed neuron simulation in Jaxley. We suggest to continue by learning how to build networks. If you are only interested in single cell simulations, you can directly jump to learning how to speed up simulations. If you want to simulate detailed morphologies from SWC files, checkout our tutorial on working with detailed morphologies.

"},{"location":"tutorial/02_small_network/","title":"Network simulations in Jaxley","text":"

In this tutorial, you will learn how to:

  • connect neurons into a network
  • visualize networks
  • use the .edges attribute to inspect and change synaptic parameters

Here is a code snippet which you will learn to understand in this tutorial:

import jaxley as jx\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import connect\n\n\n# Define a network. `cell` is defined as in previous tutorial.\nnet = jx.Network([cell for _ in range(11)])\n\n# Define synapses.\nfully_connect(\n    net.cell(range(10)),\n    net.cell(10),\n    IonotropicSynapse(),\n)\n\n# Change synaptic parameters.\nnet.select(edges=[0, 1]).set(\"IonotropicSynapse_gS\", 0.1)  # nS\n\n# Visualize the network.\nnet.compute_xyz()\nfig, ax = plt.subplots(1, 1, figsize=(4, 4))\nnet.vis(ax=ax, detail=\"full\", layers=[10, 1])  # or `detail=\"point\"`.\n

In the previous tutorial, you learned how to build single cells with morphological detail, how to insert stimuli and recordings, and how to run a first simulation. In this tutorial, we will define networks of multiple cells and connect them with synapses. Let\u2019s get started:

from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect, connect\n
"},{"location":"tutorial/02_small_network/#define-the-network","title":"Define the network","text":"

First, we define a cell as you saw in the previous tutorial.

comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n

We can assemble multiple cells into a network by using jx.Network, which takes a list of jx.Cells. Here, we assemble 11 cells into a network:

num_cells = 11\nnet = jx.Network([cell for _ in range(num_cells)])\n

At this point, we can already visualize this network:

net.compute_xyz()\nnet.rotate(180)\nnet.arrange_in_layers(layers=[10, 1], within_layer_offset=150, between_layer_offset=200)\n\nfig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\")\n

Note: you can use move_to to have more control over the location of cells, e.g.: network.cell(i).move_to(x=0, y=200).

As you can see, the neurons are not connected yet. Let\u2019s fix this by connecting neurons with synapses. We will build a network consisting of two layers: 10 neurons in the input layer and 1 neuron in the output layer.

We can use Jaxley\u2019s fully_connect method to connect these layers:

pre = net.cell(range(10))\npost = net.cell(10)\nfully_connect(pre, post, IonotropicSynapse())\n

Let\u2019s visualize this again:

fig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\")\n

As you can see, the full_connect method inserted one synapse (in blue) from every neuron in the first layer to the output neuron. The fully_connect method builds this synapse from the zero-eth compartment and zero-eth branch of the presynaptic neuron onto a random branch of the postsynaptic neuron. If you want more control over the pre- and post-synaptic branches, you can use the connect method:

pre = net.cell(0).branch(5).loc(1.0)\npost = net.cell(10).branch(0).loc(0.0)\nconnect(pre, post, IonotropicSynapse())\n
fig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\")\n

"},{"location":"tutorial/02_small_network/#inspecting-and-changing-synaptic-parameters","title":"Inspecting and changing synaptic parameters","text":"

You can inspect synaptic parameters via the .edges attribute:

net.edges\n
global_edge_index pre_global_comp_index post_global_comp_index type type_ind pre_locs post_locs IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s controlled_by_param 0 0 0 287 IonotropicSynapse 0 0.125 0.875 0.0001 0.0 0.025 0.2 0 1 1 28 289 IonotropicSynapse 0 0.125 0.375 0.0001 0.0 0.025 0.2 0 2 2 56 289 IonotropicSynapse 0 0.125 0.375 0.0001 0.0 0.025 0.2 0 3 3 84 301 IonotropicSynapse 0 0.125 0.375 0.0001 0.0 0.025 0.2 0 4 4 112 281 IonotropicSynapse 0 0.125 0.375 0.0001 0.0 0.025 0.2 0 5 5 140 295 IonotropicSynapse 0 0.125 0.875 0.0001 0.0 0.025 0.2 0 6 6 168 289 IonotropicSynapse 0 0.125 0.375 0.0001 0.0 0.025 0.2 0 7 7 196 290 IonotropicSynapse 0 0.125 0.625 0.0001 0.0 0.025 0.2 0 8 8 224 303 IonotropicSynapse 0 0.125 0.875 0.0001 0.0 0.025 0.2 0 9 9 252 280 IonotropicSynapse 0 0.125 0.125 0.0001 0.0 0.025 0.2 0 10 10 23 280 IonotropicSynapse 0 0.875 0.125 0.0001 0.0 0.025 0.2 0

To modify a parameter of all synapses you can again use .set():

net.set(\"IonotropicSynapse_gS\", 0.0003)  # nS\n

To modify individual syanptic parameters, use the .select() method. Below, we change the values of the first two synapses:

net.select(edges=[0, 1]).set(\"IonotropicSynapse_gS\", 0.0004)  # nS\n

For more details on how to flexibly set synaptic parameters (e.g., by cell type, or by pre-synaptic cell index,\u2026), see this tutorial.

"},{"location":"tutorial/02_small_network/#stimulating-recording-and-simulating-the-network","title":"Stimulating, recording, and simulating the network","text":"

We will now set up a simulation of the network. This works exactly as it does for single neurons:

# Stimulus.\ni_delay = 3.0  # ms\ni_amp = 0.05  # nA\ni_dur = 2.0  # ms\n\n# Duration and step size.\ndt = 0.025  # ms\nt_max = 50.0  # ms\n
time_vec = jnp.arange(0.0, t_max + dt, dt)\n

As a simple example, we insert sodium, potassium, and leak into every compartment of every cell of the network.

net.insert(Na())\nnet.insert(K())\nnet.insert(Leak())\n

We stimulate every neuron in the input layer and record the voltage from the output neuron:

current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)\nnet.delete_stimuli()\nfor stim_ind in range(10):\n    net.cell(stim_ind).branch(0).loc(0.0).stimulate(current)\n\nnet.delete_recordings()\nnet.cell(10).branch(0).loc(0.0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n

Finally, we can again run the network simulation and plot the result:

s = jx.integrate(net, delta_t=dt)\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T)\n

That\u2019s it! You now know how to simulate networks of morphologically detailed neurons. We recommend that you now have a look at how you can speed up your simulation. To learn more about handling synaptic parameters, we recommend to check out this tutorial.

"},{"location":"tutorial/04_jit_and_vmap/","title":"Speeding up simulations","text":"

In this tutorial, you will learn how to:

  • make parameter sweeps in Jaxley
  • use jit to compile your simulations and make them faster
  • use vmap to parallelize simulations on GPUs

Here is a code snippet which you will learn to understand in this tutorial:

from jax import jit, vmap\n\n\ncell = ...  # See tutorial on Basics of Jaxley.\n\ndef simulate(params):\n    param_state = None\n    param_state = cell.data_set(\"Na_gNa\", params[0], param_state)\n    param_state = cell.data_set(\"K_gK\", params[1], param_state)\n    return jx.integrate(cell, param_state=param_state, delta_t=0.025)\n\n# Define 100 sets of sodium and potassium conductances.\nall_params = jnp.asarray(np.random.rand(100, 2))\n\n# Fast for-loops with jit compilation.\njitted_simulate = jit(simulate)\nvoltages = [jitted_simulate(params) for params in all_params]\n\n# Using vmap for parallelization.\nvmapped_simulate = vmap(jitted_simulate, in_axes=(0,))\nvoltages = vmapped_simulate(all_params)\n

In the previous tutorials, you learned how to build single cells or networks and how to change their parameters. In this tutorial, you will learn how to speed up such simulations by many orders of magnitude. This can be achieved in to ways:

  • by using JIT compilation
  • by using GPU parallelization

Let\u2019s get started!

"},{"location":"tutorial/04_jit_and_vmap/#using-gpu-or-cpu","title":"Using GPU or CPU","text":"

In Jaxley you can set whether you want to use gpu or cpu with the following lines at the beginning of your script:

from jax import config\nconfig.update(\"jax_platform_name\", \"cpu\")\n

JAX (and Jaxley) also allow to choose between float32 and float64. Especially on GPUs, float32 will be faster, but we have experienced stability issues when simulating morphologically detailed neurons with float32.

config.update(\"jax_enable_x64\", True)  # Set to false to use `float32`.\n

Next, we will import relevant libraries:

import matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit, vmap\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\n
"},{"location":"tutorial/04_jit_and_vmap/#building-the-cell-or-network","title":"Building the cell or network","text":"

We first build a cell (or network) in the same way as we showed in the previous tutorials:

dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n\ncell.insert(Na())\ncell.insert(K())\ncell.insert(Leak())\n\ncell.delete_stimuli()\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=dt, t_max=t_max)\ncell.branch(0).loc(0.0).stimulate(current)\n\ncell.delete_recordings()\ncell.branch(0).loc(0.0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n
"},{"location":"tutorial/04_jit_and_vmap/#parameter-sweeps","title":"Parameter sweeps","text":"

Assume you want to run the same cell with many different values for the sodium and potassium conductance, for example for genetic algorithms or for parameter sweeps. To do this efficiently in Jaxley, you have to use the data_set() method (in combination with jit and vmap, as shown later):

def simulate(params):\n    param_state = None\n    param_state = cell.data_set(\"Na_gNa\", params[0], param_state)\n    param_state = cell.data_set(\"K_gK\", params[1], param_state)\n    return jx.integrate(cell, param_state=param_state, delta_t=dt)\n

The .data_set() method takes three arguments:

1) the name of the parameter you want to set. Jaxley allows to set the following parameters: \u201cradius\u201d, \u201clength\u201d, \u201caxial_resistivity\u201d, as well as all parameters of channels and synapses. 2) the value of the parameter. 3) a param_state which is initialized as None and is modified by .data_set(). This has to be passed to jx.integrate().

Having done this, the simplest (but least efficient) way to perform the parameter sweep is to run a for-loop over many parameter sets:

# Define 5 sets of sodium and potassium conductances.\nall_params = jnp.asarray(np.random.rand(5, 2))\n\nvoltages = jnp.asarray([simulate(params) for params in all_params])\nprint(\"voltages.shape\", voltages.shape)\n
voltages.shape (5, 1, 402)\n

The resulting voltages have shape (num_simulations, num_recordings, num_timesteps).

"},{"location":"tutorial/04_jit_and_vmap/#stimulus-sweeps","title":"Stimulus sweeps","text":"

In addition to running sweeps across multiple parameters, you can also run sweeeps across multiple stimuli (e.g. step current stimuli of different amplitudes. You can achieve this with the data_stimulate() method:

def simulate(i_amp):\n    current = jx.step_current(1.0, 1.0, i_amp, 0.025, 10.0)\n\n    data_stimuli = None\n    data_stimuli = cell.branch(0).comp(0).data_stimulate(current, data_stimuli)\n    return jx.integrate(cell, data_stimuli=data_stimuli)\n

"},{"location":"tutorial/04_jit_and_vmap/#speeding-up-for-loops-via-jit-compilation","title":"Speeding up for loops via jit compilation","text":"

We can speed up such parameter sweeps (or stimulus sweeps) with jit compilation. jit compilation will compile the simulation when it is run for the first time, such that every other simulation will be must faster. This can be achieved by defining a new function which uses JAX\u2019s jit():

jitted_simulate = jit(simulate)\n
# First run, will be slow.\nvoltages = jitted_simulate(all_params[0])\n
# More runs, will be much faster.\nvoltages = jnp.asarray([jitted_simulate(params) for params in all_params])\nprint(\"voltages.shape\", voltages.shape)\n
voltages.shape (5, 1, 402)\n

jit compilation can be up to 10k times faster, especially for small simulations with few compartments. For very large models, the gain obtained with jit will be much smaller (jit may even provide no speed up at all).

"},{"location":"tutorial/04_jit_and_vmap/#speeding-up-with-gpu-parallelization-via-vmap","title":"Speeding up with GPU parallelization via vmap","text":"

Another way to speed up parameter sweeps is with GPU parallelization. Parallelization in Jaxley can be achieved by using vmap of JAX. To do this, we first create a new function that handles multiple parameter sets directly:

# Using vmap for parallelization.\nvmapped_simulate = vmap(jitted_simulate)\n

We can then run this method on all parameter sets (all_params.shape == (100, 2)), and Jaxley will automatically parallelize across them. Of course, you will only get a speed-up if you have a GPU available and you specified gpu as device in the beginning of this tutorial.

voltages = vmapped_simulate(all_params)\n

GPU parallelization with vmap can give a large speed-up, which can easily be 2-3 orders of magnitude.

"},{"location":"tutorial/04_jit_and_vmap/#combining-jit-and-vmap","title":"Combining jit and vmap","text":"

Finally, you can also combine using jit and vmap. For example, you can run multiple batches of many parallel simulations. Each batch can be parallelized with vmap and simulating each batch can be compiled with jit:

jitted_vmapped_simulate = jit(vmap(simulate))\n
for batch in range(10):\n    all_params = jnp.asarray(np.random.rand(5, 2))\n    voltages_batch = jitted_vmapped_simulate(all_params)\n

That\u2019s all you have to know about jit and vmap! If you have worked through this and the previous tutorials, you should be ready to set up your first network simulations.

"},{"location":"tutorial/04_jit_and_vmap/#next-steps","title":"Next steps","text":"

If you want to learn more, we recommend you to read the tutorial on building channel and synapse models.

Alternatively, you can also directly jump ahead to the tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.

Finally, if you want to learn more about JAX, check out their tutorial on jit or their tutorial on vmap.

"},{"location":"tutorial/05_channel_and_synapse_models/","title":"Building ion channel models","text":"

In this tutorial, you will learn how to:

  • define your own ion channel models beyond the preconfigured channels in Jaxley

This tutorial assumes that you have already learned how to build basic simulations.

from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, value_and_grad\n\nimport jaxley as jx\n

First, we define a cell as you saw in the previous tutorial:

comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n

You have also already learned how to insert preconfigured channels into Jaxley models:

cell.insert(Na())\ncell.insert(K())\ncell.insert(Leak())\n

In this tutorial, we will show you how to build your own channel and synapse models.

"},{"location":"tutorial/05_channel_and_synapse_models/#your-own-channel","title":"Your own channel","text":"

Below is how you can define your own channel. We will go into detail about individual parts of the code in the next couple of cells.

import jax.numpy as jnp\nfrom jaxley.channels import Channel\nfrom jaxley.solver_gate import solve_gate_exponential\n\n\ndef exp_update_alpha(x, y):\n    return x / (jnp.exp(x / y) - 1.0)\n\nclass Potassium(Channel):\n    \"\"\"Potassium channel.\"\"\"\n\n    def __init__(self, name = None):\n        self.current_is_in_mA_per_cm2 = True\n        super().__init__(name)\n        self.channel_params = {\"gK_new\": 1e-4}\n        self.channel_states = {\"n_new\": 0.0}\n        self.current_name = \"i_K\"\n\n    def update_states(self, states, dt, v, params):\n        \"\"\"Update state.\"\"\"\n        ns = states[\"n_new\"]\n        alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\n        beta = 0.125 * jnp.exp(-(v + 65) / 80)\n        new_n = solve_gate_exponential(ns, dt, alpha, beta)\n        return {\"n_new\": new_n}\n\n    def compute_current(self, states, v, params):\n        \"\"\"Return current.\"\"\"\n        ns = states[\"n_new\"]\n        kd_conds = params[\"gK_new\"] * ns**4  # S/cm^2\n\n        e_kd = -77.0        \n        return kd_conds * (v - e_kd)\n\n    def init_state(self, states, v, params, delta_t):\n        alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\n        beta = 0.125 * jnp.exp(-(v + 65) / 80)\n        return {\"n_new\": alpha / (alpha + beta)}\n

Let\u2019s look at each part of this in detail.

The below is simply a helper function for the solver of the gate variables:

def exp_update_alpha(x, y):\n    return x / (jnp.exp(x / y) - 1.0)\n

Next, we define our channel as a class. It should inherit from the Channel class and define channel_params, channel_states, and current_name. You also need to set self.current_is_in_mA_per_cm2=True as the first line on your __init__() method. This is to acknowledge that your current is returned in mA/cm2 (not in uA/cm2, as would have been required in Jaxley versions 0.4.0 or older).

class Potassium(Channel):\n    \"\"\"Potassium channel.\"\"\"\n\n    def __init__(self, name=None):\n        self.current_is_in_mA_per_cm2 = True\n        super().__init__(name)\n        self.channel_params = {\"gK_new\": 1e-4}\n        self.channel_states = {\"n_new\": 0.0}\n        self.current_name = \"i_K\"\n

Next, we have the update_states() method, which updates the gating variables:

    def update_states(self, states, dt, v, params):\n

Every channel you define must have an update_states() method which takes exactly these five arguments (self, states, dt, v, params). The inputs states to the update_states method is a dictionary which contains all states that are updated (including states of other channels). v is a jnp.ndarray which contains the voltage of a single compartment (shape ()). Let\u2019s get the state of the potassium channel which we are building here:

ns = states[\"n_new\"]\n

Next, we update the state of the channel. In this example, we do this with exponential Euler, but you can implement any solver yourself:

alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\nbeta = 0.125 * jnp.exp(-(v + 65) / 80)\nnew_n = solve_gate_exponential(ns, dt, alpha, beta)\nreturn {\"n_new\": new_n}\n

A channel also needs a compute_current() method which returns the current throught the channel:

    def compute_current(self, states, v, params):\n        ns = states[\"n_new\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        kd_conds = params[\"gK_new\"] * ns**4  # S/cm^2\n\n        e_kd = -77.0        \n        current = kd_conds * (v - e_kd)\n        return current\n

Finally, the init_state() method can be implemented optionally. It can be used to automatically compute the initial state based on the voltage when cell.init_states() is run.

Alright, done! We can now insert this channel into any jx.Module such as our cell:

cell.insert(Potassium())\n
cell.delete_stimuli()\ncurrent = jx.step_current(1.0, 1.0, 0.1, 0.025, 10.0)\ncell.branch(0).comp(0).stimulate(current)\n\ncell.delete_recordings()\ncell.branch(0).comp(0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n
s = jx.integrate(cell)\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T[:-1])\n_ = ax.set_ylim([-80, 50])\n_ = ax.set_xlabel(\"Time (ms)\")\n_ = ax.set_ylabel(\"Voltage (mV)\")\n

"},{"location":"tutorial/05_channel_and_synapse_models/#your-own-synapse","title":"Your own synapse","text":"

The parts below assume that you have already learned how to build network simulations in Jaxley.

Note that again, a synapse needs to have the two functions update_states and compute_current with all input arguments shown below.

The below is an example of how to define your own synapse model in Jaxley:

import jax.numpy as jnp\nfrom jaxley.synapses.synapse import Synapse\n\n\nclass TestSynapse(Synapse):\n    \"\"\"\n    Compute syanptic current and update syanpse state.\n    \"\"\"\n    def __init__(self, name = None):\n        super().__init__(name)\n        self.synapse_params = {\"gChol\": 0.001, \"eChol\": 0.0}\n        self.synapse_states = {\"s_chol\": 0.1}\n\n    def update_states(self, states, delta_t, pre_voltage, post_voltage, params):\n        \"\"\"Return updated synapse state and current.\"\"\"\n        s_inf = 1.0 / (1.0 + jnp.exp((-35.0 - pre_voltage) / 10.0))\n        exp_term = jnp.exp(-delta_t)\n        new_s = states[\"s_chol\"] * exp_term + s_inf * (1.0 - exp_term)\n        return {\"s_chol\": new_s}\n\n    def compute_current(self, states, pre_voltage, post_voltage, params):\n        g_syn = params[\"gChol\"] * states[\"s_chol\"]\n        return g_syn * (post_voltage - params[\"eChol\"])\n

As you can see above, synapses follow closely how channels are defined. The main difference is that the compute_current method takes two voltages: the pre-synaptic voltage (a jnp.ndarray of shape ()) and the post-synaptic voltage (a jnp.ndarray of shape ()).

net = jx.Network([cell for _ in range(3)])\n
from jaxley.connect import connect\n\npre = net.cell(0).branch(0).loc(0.0)\npost = net.cell(1).branch(0).loc(0.0)\nconnect(pre, post, TestSynapse())\n
net.cell(0).branch(0).loc(0.0).stimulate(jx.step_current(1.0, 2.0, 0.1, 0.025, 10.0))\nfor i in range(3):\n    net.cell(i).branch(0).loc(0.0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n
s = jx.integrate(net)\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T[:-1])\n_ = ax.set_ylim([-80, 50])\n_ = ax.set_xlabel(\"Time (ms)\")\n_ = ax.set_ylabel(\"Voltage (mV)\")\n

That\u2019s it! You are now ready to build your own custom simulations and equip them with channel and synapse models!

This tutorial does not have an immediate follow-up tutorial. If you have not done so already, you can check out our tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.

"},{"location":"tutorial/06_groups/","title":"Defining groups","text":"

In this tutorial, you will learn how to:

  • define groups (aka sectionlists) to simplify iteractions with Jaxley

Here is a code snippet which you will learn to understand in this tutorial:

from jax import jit, vmap\n\n\nnet = ...  # See tutorial on Basics of Jaxley.\n\nnet.cell(0).add_to_group(\"fast_spiking\")\nnet.cell(1).add_to_group(\"slow_spiking\")\n\ndef simulate(params):\n    param_state = None\n    param_state = net.fast_spiking.data_set(\"HH_gNa\", params[0], param_state)\n    param_state = net.slow_spiking.data_set(\"HH_gNa\", params[1], param_state)\n    return jx.integrate(net, param_state=param_state)\n\n# Define sodium for fast and slow spiking neurons.\nparams = jnp.asarray([1.0, 0.1])\n\n# Run simulation.\nvoltages = simulate(params)\n

In many cases, you might want to group several compartments (or branches, or cells) and assign a unique parameter or mechanism to this group. For example, you might want to define a couple of branches as basal and then assign a Hodgkin-Huxley mechanism only to those branches. Or you might define a couple of cells as fast spiking and assign them a high value for the sodium conductance. We describe how you can do this in this tutorial.

from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport time\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, value_and_grad\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect\n

First, we define a network as you saw in the previous tutorial:

comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1])\nnetwork = jx.Network([cell for _ in range(3)])\n\npre = network.cell([0, 1])\npost = network.cell([2])\nfully_connect(pre, post, IonotropicSynapse())\n\nnetwork.insert(Na())\nnetwork.insert(K())\nnetwork.insert(Leak())\n
"},{"location":"tutorial/06_groups/#group-apical-dendrites","title":"Group: apical dendrites","text":"

Assume that, in each of the five neurons in this network, the second and forth branch are apical dendrites. We can define this as:

for cell_ind in range(3):\n    network.cell(cell_ind).branch(1).add_to_group(\"apical\")\n    network.cell(cell_ind).branch(3).add_to_group(\"apical\")\n

After this, we can access network.apical as we previously accesses anything else:

network.apical.set(\"radius\", 0.3)\n
network.apical.view\n
View with 3 different channels. Use `.nodes` for details.\n
"},{"location":"tutorial/06_groups/#group-fast-spiking","title":"Group: fast spiking","text":"

Similarly, you could define a group of fast-spiking cells. Assume that the first and second cell are fast-spiking:

network.cell(0).add_to_group(\"fast_spiking\")\nnetwork.cell(1).add_to_group(\"fast_spiking\")\n
network.fast_spiking.set(\"Na_gNa\", 0.4)\n
network.fast_spiking.view\n
View with 3 different channels. Use `.nodes` for details.\n
"},{"location":"tutorial/06_groups/#groups-from-swc-files","title":"Groups from SWC files","text":"

If you are reading .swc morphologigies, you can automatically assign groups with

jx.read_swc(file_name, nseg=n, assign_groups=True).\n
After that, you can directly use cell.soma, cell.apical, cell.basal, or cell.axon.

"},{"location":"tutorial/06_groups/#how-groups-are-interpreted-by-make_trainable","title":"How groups are interpreted by .make_trainable()","text":"

If you make a parameter of a group trainable, then it will be treated as a single shared parameter for a given property:

network.fast_spiking.make_trainable(\"Na_gNa\")\n
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n

As such, get_parameters() returns only a single trainable parameter, which will be the sodium conductance for every compartment of every fast-spiking neuron:

network.get_parameters()\n
[{'Na_gNa': Array([0.4], dtype=float64)}]\n

If, instead, you would want a separate parameter for every fast-spiking cell, you should not use the group, but instead do the following (remember that fast-spiking neurons had indices [0,1]):

network.cell([0,1]).make_trainable(\"axial_resistivity\")\n
Number of newly added trainable parameters: 2. Total number of trainable parameters: 3\n
network.get_parameters()\n
[{'Na_gNa': Array([0.4], dtype=float64)},\n {'axial_resistivity': Array([5000., 5000.], dtype=float64)}]\n

This generated two parameters for the axial resistivitiy, each corresponding to one cell.

"},{"location":"tutorial/06_groups/#summary","title":"Summary","text":"

Groups allow you to organize your simulation in a more intuitive way, and they allow to perform parameter sharing with make_trainable().

"},{"location":"tutorial/07_gradient_descent/","title":"Training biophysical models","text":"

In this tutorial, you will learn how to train biophysical models in Jaxley. This includes the following:

  • compute the gradient with respect to parameters
  • use parameter transformations
  • use multi-level checkpointing
  • define optimizers
  • write dataloaders and parallelize across data

Here is a code snippet which you will learn to understand in this tutorial:

from jax import jit, vmap, value_and_grad\nimport jaxley as jx\nimport jaxley.optimize.transforms as jt\n\nnet = ...  # See tutorial on the basics of `Jaxley`.\n\n# Define which parameters to train.\nnet.cell(\"all\").make_trainable(\"HH_gNa\")\nnet.IonotropicSynapse.make_trainable(\"IonotropicSynapse_gS\")\nparameters = net.get_parameters()\n\n# Define parameter transform and apply it to the parameters.\ntransform = jx.ParamTransform([\n    {\"IonotropicSynapse_gS\": jt.SigmoidTransform(0.0, 1.0)},\n    {\"HH_gNa\":jt.SigmoidTransform(0.0, 1, 0)}\n])\n\nopt_params = transform.inverse(parameters)\n\n# Define simulation and batch it across stimuli.\ndef simulate(params, datapoint):\n    current = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amps=datapoint, dt=0.025, t_max=5.0)\n    data_stimuli = net.cell(0).branch(0).comp(0).data_stimulate(current, None)\n    return jx.integrate(net, params=params, data_stimuli=data_stimuli, checkpoint_inds=[20, 20], delta_t=0.025)\n\nbatch_simulate = vmap(simulate, in_axes=(None, 0))\n\n# Define loss function and its gradient.\ndef loss_fn(opt_params, datapoints, label):\n    params = transform.forward(opt_params)\n    voltages = batch_simulate(params, datapoints)\n    return jnp.abs(jnp.mean(voltages) - label)\n\ngrad_fn = jit(value_and_grad(loss_fn, argnums=0))\n\n# Define data and dataloader.\ndata = jnp.asarray(np.random.randn(100, 3))\ndataloader = Dataset.from_tensor_slices((inputs, labels))\ndataloader = dataloader.shuffle(dataloader.cardinality()).batch(4)\n\n# Define the optimizer.\noptimizer = optax.Adam(lr=0.01)\nopt_state = optimizer.init_state(opt_params)\n\nfor epoch in range(10):\n    for batch in dataloader:\n        stimuli = batch[0].numpy()\n        labels = batch[1].numpy()\n        loss, gradient = grad_fn(opt_params, stimuli, labels)\n\n        # Optimizer step.\n        updates, opt_state = optimizer.update(gradient, opt_state)\n        opt_params = optax.apply_updates(opt_params, updates)\n

from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, vmap, value_and_grad\n\nimport jaxley as jx\nfrom jaxley.channels import Leak\nfrom jaxley.synapses import TanhRateSynapse\nfrom jaxley.connect import fully_connect\n

First, we define a network as you saw in the previous tutorial:

_ = np.random.seed(0)  # For synaptic locations.\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0])\nnet = jx.Network([cell for _ in range(3)])\n\npre = net.cell([0, 1])\npost = net.cell([2])\nfully_connect(pre, post, TanhRateSynapse())\n\n# Change some default values of the tanh synapse.\nnet.TanhRateSynapse.set(\"TanhRateSynapse_x_offset\", -60.0)\nnet.TanhRateSynapse.set(\"TanhRateSynapse_gS\", 1e-3)\nnet.TanhRateSynapse.set(\"TanhRateSynapse_slope\", 0.1)\n\nnet.insert(Leak())\n

This network consists of three neurons arranged in two layers:

net.compute_xyz()\nnet.rotate(180)\nnet.arrange_in_layers(layers=[2, 1], within_layer_offset=100.0, between_layer_offset=100.0)\nfig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = net.vis(ax=ax, detail=\"full\")\n

We consider the last neuron as the output neuron and record the voltage from there:

net.delete_recordings()\nnet.cell(0).branch(0).loc(0.0).record()\nnet.cell(1).branch(0).loc(0.0).record()\nnet.cell(2).branch(0).loc(0.0).record()\n
Added 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n
"},{"location":"tutorial/07_gradient_descent/#defining-a-dataset","title":"Defining a dataset","text":"

We will train this biophysical network on a classification task. The inputs will be values and the label is binary:

inputs = jnp.asarray(np.random.rand(100, 2))\nlabels = jnp.asarray((inputs[:, 0] + inputs[:, 1]) > 1.0)\n
fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = ax.scatter(inputs[labels, 0], inputs[labels, 1])\n_ = ax.scatter(inputs[~labels, 0], inputs[~labels, 1])\n

labels = labels.astype(float)\n
"},{"location":"tutorial/07_gradient_descent/#defining-trainable-parameters","title":"Defining trainable parameters","text":"
net.delete_trainables()\n

This follows the same API as .set() seen in the previous tutorial. If you want to use a single parameter for all radiuses in the entire network, do:

net.make_trainable(\"radius\")\n
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n

We can also define parameters for individual compartments. To do this, use the \"all\" key. The following defines a separate parameter the sodium conductance for every compartment in the entire network:

net.cell(\"all\").branch(\"all\").loc(\"all\").make_trainable(\"Leak_gLeak\")\n
Number of newly added trainable parameters: 18. Total number of trainable parameters: 19\n
"},{"location":"tutorial/07_gradient_descent/#making-synaptic-parameters-trainable","title":"Making synaptic parameters trainable","text":"

Synaptic parameters can be made trainable in the exact same way. To use a single parameter for all syanptic conductances in the entire network, do

net.TanhRateSynapse.make_trainable(\"TanhRateSynapse_gS\")\n

Here, we use a different syanptic conductance for all syanpses. This can be done as follows:

net.TanhRateSynapse.edge(\"all\").make_trainable(\"TanhRateSynapse_gS\")\n
Number of newly added trainable parameters: 2. Total number of trainable parameters: 21\n
"},{"location":"tutorial/07_gradient_descent/#running-the-simulation","title":"Running the simulation","text":"

Once all parameters are defined, you have to use .get_parameters() to obtain all trainable parameters. This is also the time to check how many trainable parameters your network has:

params = net.get_parameters()\n

You can now run the simulation with the trainable parameters by passing them to the jx.integrate function.

s = jx.integrate(net, params=params, t_max=10.0)\n
"},{"location":"tutorial/07_gradient_descent/#stimulating-the-network","title":"Stimulating the network","text":"

The network above does not yet get any stimuli. We will use the 2D inputs from the dataset to stimulate the two input neurons. The amplitude of the step current corresponds to the input value. Below is the simulator that defines this:

def simulate(params, inputs):\n    currents = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amp=inputs / 10, delta_t=0.025, t_max=10.0)\n\n    data_stimuli = None\n    data_stimuli = net.cell(0).branch(2).loc(1.0).data_stimulate(currents[0], data_stimuli=data_stimuli)\n    data_stimuli = net.cell(1).branch(2).loc(1.0).data_stimulate(currents[1], data_stimuli=data_stimuli)\n\n    return jx.integrate(net, params=params, data_stimuli=data_stimuli, delta_t=0.025)\n\nbatched_simulate = vmap(simulate, in_axes=(None, 0))\n

We can also inspect some traces:

traces = batched_simulate(params, inputs[:4])\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(traces[:, 2, :].T)\n

"},{"location":"tutorial/07_gradient_descent/#defining-a-loss-function","title":"Defining a loss function","text":"

Let us define a loss function to be optimized:

def loss(params, inputs, labels):\n    traces = batched_simulate(params, inputs)  # Shape `(batchsize, num_recordings, timepoints)`.\n    prediction = jnp.mean(traces[:, 2], axis=1)  # Use the average over time of the output neuron (2) as prediction.\n    prediction = (prediction + 72.0) / 5  # Such that the prediction is roughly in [0, 1].\n    losses = jnp.abs(prediction - labels)  # Mean absolute error loss.\n    return jnp.mean(losses)  # Average across the batch.\n

And we can use JAX\u2019s inbuilt functions to take the gradient through the entire ODE:

jitted_grad = jit(value_and_grad(loss, argnums=0))\n
value, gradient = jitted_grad(params, inputs[:4], labels[:4])\n
"},{"location":"tutorial/07_gradient_descent/#defining-parameter-transformations","title":"Defining parameter transformations","text":"

Before training, however, we will enforce for all parameters to be within a prespecified range (such that, e.g., conductances can not become negative)

import jaxley.optimize.transforms as jt\n
# Define a function to create appropriate transforms for each parameter\ndef create_transform(name):\n    if name == \"axial_resistivity\":\n        # Must be positive; apply Softplus and scale to match initialization\n        return jt.ChainTransform([jt.SoftplusTransform(0), jt.AffineTransform(5000, 0)])\n    elif name == \"length\":\n        # Apply Softplus and affine transform for the 'length' parameter\n        return jt.ChainTransform([jt.SoftplusTransform(0), jt.AffineTransform(10, 0)])\n    else:\n        # Default to a Softplus transform for other parameters\n        return jt.SoftplusTransform(0)\n\n# Apply the transforms to the parameters\ntransforms = [{k: create_transform(k) for k in param} for param in params]\ntf = jt.ParamTransform(transforms)\n
transform = jx.ParamTransform([{\"radius\": jt.SigmoidTransform(0.1, 5.0)},\n                               {\"Leak_gLeak\":jt.SigmoidTransform(1e-5, 1e-3)},\n                               {\"TanhRateSynapse_gS\" : jt.SigmoidTransform(1e-5, 1e-2)}])\n

With these modify the loss function acocrdingly:

def loss(opt_params, inputs, labels):\n    transform.forward(opt_params)\n\n    traces = batched_simulate(params, inputs)  # Shape `(batchsize, num_recordings, timepoints)`.\n    prediction = jnp.mean(traces[:, 2], axis=1)  # Use the average over time of the output neuron (2) as prediction.\n    prediction = (prediction + 72.0)  # Such that the prediction is around 0.\n    losses = jnp.abs(prediction - labels)  # Mean absolute error loss.\n    return jnp.mean(losses)  # Average across the batch.\n
"},{"location":"tutorial/07_gradient_descent/#using-checkpointing","title":"Using checkpointing","text":"

Checkpointing allows to vastly reduce the memory requirements of training biophysical models (see also JAX\u2019s full tutorial on checkpointing).

t_max = 5.0\ndt = 0.025\n\nlevels = 2\ntime_points = t_max // dt + 2\ncheckpoints = [int(np.ceil(time_points**(1/levels))) for _ in range(levels)]\n

To enable checkpointing, we have to modify the simulate function appropriately and use

jx.integrate(..., checkpoint_inds=checkpoints)\n
as done below:

def simulate(params, inputs):\n    currents = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amp=inputs / 10.0, delta_t=dt, t_max=t_max)\n\n    data_stimuli = None\n    data_stimuli = net.cell(0).branch(2).loc(1.0).data_stimulate(currents[0], data_stimuli=data_stimuli)\n    data_stimuli = net.cell(1).branch(2).loc(1.0).data_stimulate(currents[1], data_stimuli=data_stimuli)\n\n    return jx.integrate(net, params=params, data_stimuli=data_stimuli, checkpoint_lengths=checkpoints)\n\nbatched_simulate = vmap(simulate, in_axes=(None, 0))\n\n\ndef predict(params, inputs):\n    traces = simulate(params, inputs)  # Shape `(batchsize, num_recordings, timepoints)`.\n    prediction = jnp.mean(traces[2])  # Use the average over time of the output neuron (2) as prediction.\n    return prediction + 72.0  # Such that the prediction is around 0.\n\nbatched_predict = vmap(predict, in_axes=(None, 0))\n\n\ndef loss(opt_params, inputs, labels):\n    params = transform.forward(opt_params)\n\n    predictions = batched_predict(params, inputs)\n    losses = jnp.abs(predictions - labels)  # Mean absolute error loss.\n    return jnp.mean(losses)  # Average across the batch.\n\njitted_grad = jit(value_and_grad(loss, argnums=0))\n
"},{"location":"tutorial/07_gradient_descent/#training","title":"Training","text":"

We will use the ADAM optimizer from the optax library to optimize the free parameters (you have to install the package with pip install optax first):

import optax\n
opt_params = transform.inverse(params)\noptimizer = optax.adam(learning_rate=0.01)\nopt_state = optimizer.init(opt_params)\n
"},{"location":"tutorial/07_gradient_descent/#writing-a-dataloader","title":"Writing a dataloader","text":"

Below, we just write our own (very simple) dataloader. Alternatively, you could use the dataloader from any deep learning library such as pytorch or tensorflow:

class Dataset:\n    def __init__(self, inputs: np.ndarray, labels: np.ndarray):\n        \"\"\"Simple Dataloader.\n\n        Args:\n            inputs: Array of shape (num_samples, num_dim)\n            labels: Array of shape (num_samples,)\n        \"\"\"\n        assert len(inputs) == len(labels), \"Inputs and labels must have same length\"\n        self.inputs = inputs\n        self.labels = labels\n        self.num_samples = len(inputs)\n        self._rng_state = None\n        self.batch_size = 1\n\n    def shuffle(self, seed=None):\n        \"\"\"Shuffle the dataset in-place\"\"\"\n        self._rng_state = np.random.get_state()[1][0] if seed is None else seed\n        np.random.seed(self._rng_state)\n        indices = np.random.permutation(self.num_samples)\n        self.inputs = self.inputs[indices]\n        self.labels = self.labels[indices]\n        return self\n\n    def batch(self, batch_size):\n        \"\"\"Create batches of the data\"\"\"\n        self.batch_size = batch_size\n        return self\n\n    def __iter__(self):\n        self.shuffle(seed=self._rng_state)\n        for start in range(0, self.num_samples, self.batch_size):\n            end = min(start + self.batch_size, self.num_samples)\n            yield self.inputs[start:end], self.labels[start:end]\n        self._rng_state += 1\n
"},{"location":"tutorial/07_gradient_descent/#training-loop","title":"Training loop","text":"
batch_size = 4\ndataloader = Dataset(inputs, labels)\ndataloader = dataloader.shuffle(seed=0).batch(batch_size)\n\nfor epoch in range(10):\n    epoch_loss = 0.0\n\n    for batch_ind, batch in enumerate(dataloader):\n        current_batch, label_batch = batch\n        loss_val, gradient = jitted_grad(opt_params, current_batch, label_batch)\n        updates, opt_state = optimizer.update(gradient, opt_state)\n        opt_params = optax.apply_updates(opt_params, updates)\n        epoch_loss += loss_val\n\n    print(f\"epoch {epoch}, loss {epoch_loss}\")\n\nfinal_params = transform.forward(opt_params)\n
epoch 0, loss 25.033223182772293\nepoch 1, loss 21.00894915349165\nepoch 2, loss 15.092242959956026\nepoch 3, loss 9.061544660383163\nepoch 4, loss 6.925509860325612\nepoch 5, loss 6.273630037897756\nepoch 6, loss 6.1757316054693145\nepoch 7, loss 6.135132525725265\nepoch 8, loss 6.145608619185389\nepoch 9, loss 6.135660902068834\n
ntest = 32\npredictions = batched_predict(final_params, inputs[:ntest])\n
fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = ax.scatter(labels[:ntest], predictions)\n_ = ax.set_xlabel(\"Label\")\n_ = ax.set_ylabel(\"Prediction\")\n

Indeed, the loss goes down and the network successfully classifies the patterns.

"},{"location":"tutorial/07_gradient_descent/#summary","title":"Summary","text":"

Puh, this was a pretty dense tutorial with a lot of material. You should have learned how to:

  • compute the gradient with respect to parameters
  • use parameter transformations
  • use multi-level checkpointing
  • define optimizers
  • write dataloaders and parallelize across data

This was the last \u201cbasic\u201d tutorial of the Jaxley toolbox. If you want to learn more, check out our Advanced Tutorials. If anything is still unclear please create a discussion. If you find any bugs, please open an issue. Happy coding!

"},{"location":"tutorial/08_importing_morphologies/","title":"Working with morphologies","text":"

In this tutorial, you will learn how to:

  • Load morphologies and make them compatible with Jaxley
  • Use the visualization features
  • Assemble a small network of morphologically accurate cells.

Here is a code snippet which you will learn to understand in this tutorial:

import jaxley as jx\n\ncell = jx.read_swc(\"my_cell.swc\", ncomp=4)\ncell.branch(2).set_ncomp(2)  # Modify the number of compartments of a branch.\n

To work with more complicated morphologies, Jaxley supports importing morphological reconstructions via .swc files. .swc is currently the only supported format. Other formats like .asc need to be converted to .swc first, for example using the BlueBrain\u2019s morph-tool. For more information on the exact specifications of .swc see here.

import jaxley as jx\nfrom jaxley.synapses import IonotropicSynapse\nimport matplotlib.pyplot as plt\n

To work with .swc files, Jaxley implements a custom .swc reader. The reader traces the morphology and identifies all uninterrupted sections. These uninterrupted sections are called branches in Jaxley. Each branch is then further partitioned into compartments.

To demonstrate this, let\u2019s import an example morphology of a Layer 5 pyramidal cell and visualize it.

# import swc file into jx.Cell object\nfname = \"data/morph.swc\"\ncell = jx.read_swc(fname, ncomp=8)  # Use eight compartments per branch.\n\n# print shape (num_branches, num_comps)\nprint(cell.shape)\n\ncell.show()\n
(157, 1256)\n
local_comp_index global_comp_index local_branch_index global_branch_index local_cell_index global_cell_index 0 0 0 0 0 0 0 1 1 1 0 0 0 0 2 2 2 0 0 0 0 3 3 3 0 0 0 0 4 4 4 0 0 0 0 ... ... ... ... ... ... ... 1251 3 1251 156 156 0 0 1252 4 1252 156 156 0 0 1253 5 1253 156 156 0 0 1254 6 1254 156 156 0 0 1255 7 1255 156 156 0 0

1256 rows \u00d7 6 columns

As we can see, this yields a morphology that is approximated by 1256 compartments. Depending on the amount of detail that you need, you can also change the number of compartments in each branch:

cell = jx.read_swc(fname, ncomp=2)\n\n# print shape (num_branches, num_comps)\nprint(cell.shape)\n\ncell.show()\n
(157, 314)\n
local_comp_index global_comp_index local_branch_index global_branch_index local_cell_index global_cell_index 0 0 0 0 0 0 0 1 1 1 0 0 0 0 2 0 2 1 1 0 0 3 1 3 1 1 0 0 4 0 4 2 2 0 0 ... ... ... ... ... ... ... 309 1 309 154 154 0 0 310 0 310 155 155 0 0 311 1 311 155 155 0 0 312 0 312 156 156 0 0 313 1 313 156 156 0 0

314 rows \u00d7 6 columns

The above assigns the same number of compartments to every branch. To use a different number of compartments in individual branches, you can use .set_ncomp():

cell.branch(1).set_ncomp(4)\n

As you can see below, branch 0 has two compartments (because this is what was passed to jx.read_swc(..., ncomp=2)), but branch 1 has four compartments:

cell.branch([0, 1]).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 0.050000 8.119000 5000.0 1.0 -70.0 0 0 0 0 1 0 0 1 0.050000 8.119000 5000.0 1.0 -70.0 0 0 1 0 2 0 1 0 3.120779 7.806172 5000.0 1.0 -70.0 0 1 2 1 3 0 1 1 3.120779 7.111231 5000.0 1.0 -70.0 0 1 3 1 4 0 1 2 3.120779 5.652394 5000.0 1.0 -70.0 0 1 4 1 5 0 1 3 3.120779 3.869247 5000.0 1.0 -70.0 0 1 5 1

Once imported the compartmentalized morphology can be viewed using vis.

# visualize the cell\ncell.vis()\nplt.axis(\"off\")\nplt.title(\"L5PC\")\nplt.show()\n

vis can be called on any jx.Module and every View of the module. This means we can also for example use vis to highlight each branch. This can be done by iterating over each branch index and calling cell.branch(i).vis(). Within the loop.

fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n# define colorwheel with 10 colors\ncolors = plt.cm.tab10.colors\nfor i, branch in enumerate(cell.branches):\n    branch.vis(ax=ax, color=colors[i % 10])\nplt.axis(\"off\")\nplt.title(\"Branches\")\nplt.show()\n

While we only use two compartments to approximate each branch in this example, we can see the morphology is still plotted in great detail. This is because we always plot the full .swc reconstruction irrespective of the number of compartments used. The morphology lives seperately in the cell.xyzr attribute in a per branch fashion.

In addition to plotting the full morphology of the cell using points vis(type=\"scatter\") or lines vis(type=\"line\"), Jaxley also supports plotting a detailed morphological vis(type=\"morph\") or approximate compartmental reconstruction vis(type=\"comp\") that correctly considers the thickness of the neurite. Note that \"comp\" plots the lengths of each compartment which is equal to the length of the traced neurite. While neurites can be zigzaggy, the compartments that approximate them are straight lines. This can lead to miss-aligment of the compartment ends. For details see the documentation of vis.

The morphologies can either be projected onto 2D or also rendered in 3D.

# visualize the cell\nfig, ax = plt.subplots(1, 4, figsize=(10, 3), layout=\"constrained\", sharex=True, sharey=True)\ncell.vis(ax=ax[0], type=\"morph\", dims=[0,1])\ncell.vis(ax=ax[1], type=\"comp\", dims=[0,1])\ncell.vis(ax=ax[2], type=\"scatter\", dims=[0,1], s=1)\ncell.vis(ax=ax[3], type=\"line\", dims=[0,1])\nfig.suptitle(\"Comparison of plot types\")\nplt.show()\n

# set to interactive mode\n# %matplotlib notebook\n
# plot in 3D\nfig = plt.figure()\nax = fig.add_subplot(111, projection='3d')\ncell.vis(ax=ax, type=\"line\", dims=[2,0,1])\nax.view_init(elev=20, azim=5)\nplt.show()\n

Since Jaxley supports grouping different branches or compartments together, we can also use the id labels provided by the .swc file to assign group labels to the jx.Cell object.

print(list(cell.groups.keys()))\n\nfig, ax = plt.subplots(1, 1, figsize=(5, 5))\ncolors = plt.cm.tab10.colors\ncell.basal.vis(ax=ax, color=colors[2])\ncell.soma.vis(ax=ax, color=colors[1])\ncell.apical.vis(ax=ax, color=colors[0])\nplt.axis(\"off\")\nplt.title(\"Groups\")\nplt.show()\n
['soma', 'basal', 'apical', 'custom']\n

To build a network of morphologically detailed cells, we can now connect several reconstructed cells together and also visualize the network. However, since all cells are going to have the same center, Jaxley will naively plot all of them on top of each other. To seperate out the cells, we therefore have to move them to a new location first.

net = jx.Network([cell]*5)\njx.connect(net[0,0,0], net[2,0,0], IonotropicSynapse())\njx.connect(net[0,0,0], net[3,0,0], IonotropicSynapse())\njx.connect(net[0,0,0], net[4,0,0], IonotropicSynapse())\n\njx.connect(net[1,0,0], net[2,0,0], IonotropicSynapse())\njx.connect(net[1,0,0], net[3,0,0], IonotropicSynapse())\njx.connect(net[1,0,0], net[4,0,0], IonotropicSynapse())\n\nnet.rotate(-90)\n\nnet.cell(0).move(0, 300)\nnet.cell(1).move(0, 500)\n\nnet.cell(2).move(900, 200)\nnet.cell(3).move(900, 400)\nnet.cell(4).move(900, 600)\n\nnet.vis()\nplt.axis(\"off\")\nplt.show()\n

Congrats! You have now learned how to vizualize and build networks out of very complex morphologies. To simulate this network, you can follow the steps in the tutorial on how to build a network.

"},{"location":"tutorial/09_advanced_indexing/","title":"Customizing synaptic parameters","text":"

In this tutorial, you will learn how to:

  • use the select() method to fully customize network simulations with Jaxley.
  • use the copy_node_property_to_edges() method to flexibly modify synapses.

Here is a code snippet which you will learn to understand in this tutorial:

net = ...  # See tutorial on Basics of Jaxley.\n\n# Set synaptic conductance of the synapse with index 0 and 1.\nnet.select(edges=[0, 1]).set(\"Ionotropic_gS\", 0.1)\n\n# Set synaptic conductance of all synapses that have cells 3 or 4 as presynaptic neuron.\nnet.copy_node_property_to_edges(\"global_cell_index\")\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [3, 4]\")\nnet.select(edges=df.index).set(\"Ionotropic_gS\", 0.2)\n\n# Set synaptic conductance of all synapses that\n# 1) have cells 2 or 3 as presynaptic neuron and\n# 2) has cell 5 as postsynaptic neuron\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [2, 3]\")\ndf = df.query(\"post_global_cell_index == 5\")\nnet.select(edges=df.index).set(\"Ionotropic_gS\", 0.3)\n

In a previous tutorial you learned how to set parameters of a jx.Network. In that tutorial, we briefly mentioned the select() method which allowed to set individual synapses to particular values. In this tutorial, we will go into detail in how you can fully customize your Jaxley simulation.

Let\u2019s go!

import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.connect import fully_connect\nfrom jaxley.synapses import IonotropicSynapse\n
"},{"location":"tutorial/09_advanced_indexing/#preface-building-the-network","title":"Preface: Building the network","text":"

We first build a network consisting of six neurons, in the same way as we showed in the previous tutorials:

dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=2)\ncell = jx.Cell(branch, parents=[-1, 0])\nnet = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n
"},{"location":"tutorial/09_advanced_indexing/#setting-individual-synapse-parameters","title":"Setting individual synapse parameters","text":"

As always, you can use the .edges table to inspect synaptic parameters of the network:

net.edges\n
global_edge_index pre_global_comp_index post_global_comp_index type type_ind pre_locs post_locs IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s controlled_by_param 0 0 0 13 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 1 1 0 19 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 2 2 0 20 IonotropicSynapse 0 0.25 0.25 0.0001 0.0 0.025 0.2 0 3 3 4 12 IonotropicSynapse 0 0.25 0.25 0.0001 0.0 0.025 0.2 0 4 4 4 16 IonotropicSynapse 0 0.25 0.25 0.0001 0.0 0.025 0.2 0 5 5 4 21 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 6 6 8 13 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 7 7 8 17 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 8 8 8 21 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0

This table has nine rows, each corresponding to one synapse. This makes sense because we fully connected three neurons (0, 1, 2) to three other neurons (3, 4, 5), giving a total of 3x3=9 synapses.

You can modify parameters of individual synapses as follows:

net.select(edges=[3, 4, 5]).set(\"IonotropicSynapse_gS\", 0.2)\n

Above, we are modifying the synapses with indices [3, 4, 5] (i.e., the indices of the net.edges DataFrame). The resulting values are indeed changed:

net.edges.IonotropicSynapse_gS\n
0    0.0001\n1    0.0001\n2    0.0001\n3    0.2000\n4    0.2000\n5    0.2000\n6    0.0001\n7    0.0001\n8    0.0001\nName: IonotropicSynapse_gS, dtype: float64\n
"},{"location":"tutorial/09_advanced_indexing/#example-1-setting-synaptic-parameters-which-connect-particular-neurons","title":"Example 1: Setting synaptic parameters which connect particular neurons","text":"

This is great, but setting synaptic parameters just by their index can be exhausting, in particular in very large networks. Instead, we would want to, for example, set the maximal conductance of all synapses that connect from cell 0 or 1 to any other neuron.

In Jaxley, such customization can be achieved by filtering the .edges dataframe accordingly, as shown below:

net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\nnet.copy_node_property_to_edges(\"global_cell_index\")\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [0, 1]\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.23)\n
net.edges.IonotropicSynapse_gS\n
0    0.2300\n1    0.2300\n2    0.2300\n3    0.2300\n4    0.2300\n5    0.2300\n6    0.0001\n7    0.0001\n8    0.0001\nName: IonotropicSynapse_gS, dtype: float64\n

Indeed, the first six synapses now have the value 0.23! Let\u2019s look at the individual lines to understand how this worked:

We want to set parameter by cell index. However, by default, the pre- or post-synaptic cell-indices are not listed in net.edges. We can add the cell index to the .edges dataframe by calling .copy_node_property_to_edges():

net.copy_node_property_to_edges(\"global_cell_index\")\n

After this, the pre- and post-synaptic cell indices are listed in net.edges as pre_global_cell_index and post_global_cell_index.

Next, we take .edges, which is a pandas DataFrame:

df = net.edges\n

We then modify this DataFrame to only contain those rows where the global cell index is in 0 or 1:

df = df.query(\"pre_global_cell_index in [0, 1]\")\n

For the above step, you use any column of the DataFrame to filter it (you can see all columns with df.columns). Note that, while we used .query() here, you can really filter the pandas DataFrame however you want. For example, the query above is identical to df = df[df[\"pre_global_cell_index\"].isin([0, 1])].

Finally, we use the .select() method, which returns a subset of the Network at the specified indices. This subset of the network can be modified with .set():

net.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.23)\n

"},{"location":"tutorial/09_advanced_indexing/#example-2-setting-parameters-given-pre-and-post-synaptic-cell-indices","title":"Example 2: Setting parameters given pre- and post-synaptic cell indices","text":"

Say you want to select all synapses that have cells 1 or 2 as presynaptic neuron and cell 4 or 5 as postsynaptic neuron.

net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n

Just like before, we can simply use .query() as already shown above. However, this time, call .query() to twice to filter by pre- and post-synaptic cell indices:

net.copy_node_property_to_edges(\"global_cell_index\")\n\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [1, 2]\")\ndf = df.query(\"post_global_cell_index in [4, 5]\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.3)\n
net.edges.IonotropicSynapse_gS\n
0    0.0001\n1    0.0001\n2    0.0001\n3    0.0001\n4    0.3000\n5    0.3000\n6    0.0001\n7    0.3000\n8    0.3000\nName: IonotropicSynapse_gS, dtype: float64\n
"},{"location":"tutorial/09_advanced_indexing/#example-3-applying-this-strategy-to-cell-level-parameters","title":"Example 3: Applying this strategy to cell level parameters","text":"

You had previously seen that you can modify parameters with, e.g., net.cell(0).set(...). However, if you need more flexibility than this, you can also use the above strategy to modify cell-level parameters:

net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\ndf = net.nodes\ndf = df.query(\"global_cell_index in [0, 1]\")\nnet.select(nodes=df.index).set(\"radius\", 0.1)\n
"},{"location":"tutorial/09_advanced_indexing/#example-4-flexibly-setting-parameters-based-on-their-groups","title":"Example 4: Flexibly setting parameters based on their groups","text":"

If you are using groups, as shown in this tutorial, then you can also use this for querying synapses. To demonstrate this, let\u2019s create a group of excitatory neurons (e.g., cells 0, 3, 5):

# Redefine network.\nnet = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\nnet.cell([0, 3, 5]).add_to_group(\"exc\")\n

Now, say we want all synapses that start from these excitatory neurons. You can do this as follows:

# First, we have to identify which cells are in the `exc` group.\nindices_of_excitatory_cells = net.exc.nodes[\"global_cell_index\"].unique().tolist()  # [0, 3, 5]\n\n# Then we can proceed as before:\nnet.copy_node_property_to_edges(\"global_cell_index\")\ndf = net.edges\ndf = df.query(f\"pre_global_cell_index in {indices_of_excitatory_cells}\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.4)\n
"},{"location":"tutorial/09_advanced_indexing/#example-5-setting-synaptic-parameters-based-on-properties-of-the-presynaptic-cell","title":"Example 5: Setting synaptic parameters based on properties of the presynaptic cell","text":"

Let\u2019s discuss one more example: Imagine we only want to modify those synapses whose presynaptic compartment has a sodium channel. Let\u2019s first add a sodium channel to some of the cells:

net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\nnet.cell(0).branch(0).comp(0).insert(Na())\nnet.cell(2).branch(1).comp(1).insert(Na())\n

Now, let us query which cells have the desired synapses:

df = net.nodes\ndf = df.query(\"Na\")\nindices_of_sodium_compartments = df[\"global_comp_index\"].unique().tolist()\n

indices_of_sodium_compartments lists all compartments which contained sodium:

print(indices_of_sodium_compartments)\n
[0, 11]\n

Then, we can proceed as always and filter for the global pre-synaptic compartment index:

df = net.edges\ndf = df.query(f\"pre_global_comp_index in {indices_of_sodium_compartments}\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.6)\n
net.edges.IonotropicSynapse_gS\n
0    0.6000\n1    0.6000\n2    0.6000\n3    0.0001\n4    0.0001\n5    0.0001\n6    0.0001\n7    0.0001\n8    0.0001\nName: IonotropicSynapse_gS, dtype: float64\n

Indeed, only synapses coming from the first neuron were modified (as its presynaptic compartment contained sodium), in contrast to synapses from neuron 2 (whose presynaptic compartment did not).

"},{"location":"tutorial/09_advanced_indexing/#summary","title":"Summary","text":"

In this tutorial, you learned how to fully customize your Jaxley simulation. This works by querying rows from the .edges DataFrame.

"},{"location":"tutorial/10_advanced_parameter_sharing/","title":"Synaptic parameter sharing","text":"

In this tutorial, you will learn how to:

  • flexibly share parameters of synapses

Here is a code snippet which you will learn to understand in this tutorial:

net = ...  # See tutorial on Basics of Jaxley.\n\n# The same parameter for all synapses\nnet.make_trainable(\"Ionotropic_gS\")\n\n# An individual parameter for every synapse.\nnet.select(edges=\"all\").make_trainable(\"Ionotropic_gS\")\n\n# Share synaptic conductances emerging from the same neurons.\nnet.copy_node_property_to_edges(\"cell_index\")\nsub_net = net.select(edges=[0, 1, 2])\nsub_net.edges[\"controlled_by_param\"] = sub_net.edges[\"pre_global_cell_index\"]\nsub_net.make_trainable(\"Ionotropic_gS\")\n

In a previous tutorial about training networks, we briefly touched on parameter sharing. In this tutorial, we will show you how you can flexibly share parameters within a network.

import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.connect import fully_connect\nfrom jaxley.synapses import IonotropicSynapse\n
"},{"location":"tutorial/10_advanced_parameter_sharing/#preface-building-the-network","title":"Preface: Building the network","text":"

We first build a network consisting of six neurons, in the same way as we showed in the previous tutorials:

dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0])\nnet = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n
"},{"location":"tutorial/10_advanced_parameter_sharing/#sharing-parameters-by-modifying-controlled_by_param","title":"Sharing parameters by modifying controlled_by_param","text":"
net.copy_node_property_to_edges(\"global_cell_index\")\n\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [0, 1, 2]\")\nsubnetwork = net.select(edges=df.index)\n\ndf = subnetwork.edges\ndf[\"controlled_by_param\"] = df[\"pre_global_cell_index\"]\nsubnetwork.make_trainable(\"IonotropicSynapse_gS\")\n
Number of newly added trainable parameters: 3. Total number of trainable parameters: 3\n

Let\u2019s look at this line by line. First, we exactly follow the previous tutorial in selecting the synapses which we are interested in training (i.e., the ones whose presynaptic neuron has index 0, 1, 2):

df = net.edges\ndf = df.query(\"pre_global_cell_index in [0, 1, 2]\")\nsubnetwork = net.select(edges=df.index)\n

As second step, we enable parameter sharing. This is done by setting the controlled_by_param. Synapses that have the same value in controlled_by_param will be shared. Let\u2019s inspect controlled_by_param before we modify it:

subnetwork.edges[[\"pre_global_cell_index\", \"controlled_by_param\"]]\n
pre_global_cell_index controlled_by_param 0 0 0 1 0 1 2 0 2 3 1 3 4 1 4 5 1 5 6 2 6 7 2 7 8 2 8

Every synapse has a different value. Because of this, no synaptic parameters will be shared. To enable parameter sharing we override the controlled_by_param column with the presynaptic cell index:

df = subnetwork.edges\ndf[\"controlled_by_param\"] = df[\"pre_global_cell_index\"]\n
df[[\"pre_global_cell_index\", \"controlled_by_param\"]]\n
pre_global_cell_index controlled_by_param 0 0 0 1 0 0 2 0 0 3 1 1 4 1 1 5 1 1 6 2 2 7 2 2 8 2 2

Now, all we have to do is to make these synaptic parameters trainable with the make_trainable() method:

subnetwork.make_trainable(\"IonotropicSynapse_gS\")\n
Number of newly added trainable parameters: 3. Total number of trainable parameters: 6\n

It correctly says that we added three parameters (because we have three cells, and we share individual synaptic parameters). We now have 6 trainable parameters in total (because we already added 3 trainable parameters above).

"},{"location":"tutorial/10_advanced_parameter_sharing/#a-more-involved-example-sharing-by-pre-and-post-synaptic-cell-type","title":"A more involved example: sharing by pre- and post-synaptic cell type","text":"

As an example, consider the following: We have a fully connected network of six cells. Each cell falls into one of three cell types:

from typing import Union, List\n
net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell(\"all\"), net.cell(\"all\"), IonotropicSynapse())\n\nnet.cell([0, 1]).add_to_group(\"exc\")\nnet.cell([2, 3]).add_to_group(\"inh\")\nnet.cell([4, 5]).add_to_group(\"unknown\")\n

We want to make all synapses that start from excitatory or inhibitory neurons trainable. In addition, we want to use the same parameter for synapses if they have the same pre- and post-synaptic cell type.

To achieve this, we will first want a column in net.nodes which indicates the cell type.

for group, inds in net.groups.items():\n    net.nodes.loc[inds, \"cell_type\"] = group\n
net.nodes[\"cell_type\"]\n
0         exc\n1         exc\n2         exc\n3         exc\n4         exc\n5         exc\n6         exc\n7         exc\n8         inh\n9         inh\n10        inh\n11        inh\n12        inh\n13        inh\n14        inh\n15        inh\n16    unknown\n17    unknown\n18    unknown\n19    unknown\n20    unknown\n21    unknown\n22    unknown\n23    unknown\nName: cell_type, dtype: object\n

The cell_type is now part of the net.nodes. However, we would like to do parameter sharing of synapses based on the pre- and post-synaptic node values. To do so, we import the cell_type column into net.edges. To do this, we use the .copy_node_property_to_edges() which the name of the property you are copying from nodes:

net.copy_node_property_to_edges(\"cell_type\")\n

After this, you have columns in the .edges which indicate the pre- and post-synaptic cell type:

net.edges[[\"pre_cell_type\", \"post_cell_type\"]]\n
pre_cell_type post_cell_type 0 exc exc 1 exc exc 2 exc inh 3 exc inh 4 exc unknown 5 exc unknown 6 exc exc 7 exc exc 8 exc inh 9 exc inh 10 exc unknown 11 exc unknown 12 inh exc 13 inh exc 14 inh inh 15 inh inh 16 inh unknown 17 inh unknown 18 inh exc 19 inh exc 20 inh inh 21 inh inh 22 inh unknown 23 inh unknown 24 unknown exc 25 unknown exc 26 unknown inh 27 unknown inh 28 unknown unknown 29 unknown unknown 30 unknown exc 31 unknown exc 32 unknown inh 33 unknown inh 34 unknown unknown 35 unknown unknown

Next, we specify which parts of the network we actually want to change (in this case, all synapses which have excitatory or inhibitory presynaptic neurons):

df = net.edges\ndf = df.query(f\"pre_cell_type in ['exc', 'inh']\")\nprint(f\"There are {len(df)} synapses to be changed.\")\n\nsubnetwork = net.select(edges=df.index)\n
There are 24 synapses to be changed.\n

As the last step, we again have to specify parameter sharing by setting controlled_by_param. In this case, we want to share parameters that have the same pre- and post-synaptic neuron. We achieve this by grouping the synpases by their pre- and post-synaptic cell type (see pd.DataFrame.groupby for details):

# Step 6: use groupby to specify parameter sharing and make the parameters trainable.\nsubnetwork.edges[\"controlled_by_param\"] = subnetwork.edges.groupby([\"pre_cell_type\", \"post_cell_type\"]).ngroup()\nsubnetwork.make_trainable(\"IonotropicSynapse_gS\")\n
Number of newly added trainable parameters: 6. Total number of trainable parameters: 6\n

This created six trainable parameters, which makes sense as we have two types of pre-synaptic neurons (excitatory and inhibitory) and each has three options for the postsynaptic neuron (pre, post, unknown).

"},{"location":"tutorial/10_advanced_parameter_sharing/#summary","title":"Summary","text":"

In this tutorial, you learned how you can flexibly share synaptic parameters. This works by first using select() to identify which synapses to make trainable, and by then modifying controlled_by_param to customize parameter sharing.

"}]} \ No newline at end of file +{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"Home","text":"

The official documentation for Jaxley has moved to jaxley.readthedocs.io. The website you are currently on will be taken down in the future.

Jaxley is a differentiable simulator for biophysical neuron models in JAX. Its key features are:

  • automatic differentiation, allowing gradient-based optimization of thousands of parameters
  • support for CPU, GPU, or TPU without any changes to the code
  • jit-compilation, making it as fast as other packages while being fully written in python
  • backward-Euler solver for stable numerical solution of multicompartment neurons
  • elegant mechanisms for parameter sharing
"},{"location":"#getting-started","title":"Getting started","text":"

Jaxley allows to simulate biophysical neuron models on CPU, GPU, or TPU:

import matplotlib.pyplot as plt\nfrom jax import config\n\nimport jaxley as jx\nfrom jaxley.channels import HH\n\nconfig.update(\"jax_platform_name\", \"cpu\")  # Or \"gpu\" / \"tpu\".\n\ncell = jx.Cell()  # Define cell.\ncell.insert(HH())  # Insert channels.\n\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=10.0)\ncell.stimulate(current)  # Stimulate with step current.\ncell.record(\"v\")  # Record voltage.\n\nv = jx.integrate(cell)  # Run simulation.\nplt.plot(v.T)  # Plot voltage trace.\n

If you want to learn more, we have tutorials on how to:

  • simulate morphologically detailed neurons
  • simulate networks of such neurons
  • set parameters of cells and networks
  • speed up simulations with GPUs and jit
  • define your own channels and synapses
  • define groups
  • read and handle SWC files
  • compute the gradient and train biophysical models
"},{"location":"#installation","title":"Installation","text":"

Jaxley is available on pypi:

pip install jaxley\n
This will install Jaxley with CPU support. If you want GPU support, follow the instructions on the JAX github repository to install JAX with GPU support (in addition to installing Jaxley). For example, for NVIDIA GPUs, run
pip install -U \"jax[cuda12]\"\n

"},{"location":"#feedback-and-contributions","title":"Feedback and Contributions","text":"

We welcome any feedback on how Jaxley is working for your neuron models and are happy to receive bug reports, pull requests and other feedback (see contribute). We wish to maintain a positive community, please read our Code of Conduct.

"},{"location":"#license","title":"License","text":"

Apache License Version 2.0 (Apache-2.0)

"},{"location":"#citation","title":"Citation","text":"

If you use Jaxley, consider citing the corresponding paper:

@article{deistler2024differentiable,\n  doi = {10.1101/2024.08.21.608979},\n  year = {2024},\n  publisher = {Cold Spring Harbor Laboratory},\n  author = {Deistler, Michael and Kadhim, Kyra L. and Pals, Matthijs and Beck, Jonas and Huang, Ziwei and Gloeckler, Manuel and Lappalainen, Janne K. and Schr{\\\"o}der, Cornelius and Berens, Philipp and Gon{\\c c}alves, Pedro J. and Macke, Jakob H.},\n  title = {Differentiable simulation enables large-scale training of detailed biophysical models of neural dynamics},\n  journal = {bioRxiv}\n}\n
"},{"location":"code_of_conduct/","title":"Contributor Covenant Code of Conduct","text":""},{"location":"code_of_conduct/#our-pledge","title":"Our Pledge","text":"

We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.

"},{"location":"code_of_conduct/#our-standards","title":"Our Standards","text":"

Examples of behavior that contributes to a positive environment for our community include:

  • Demonstrating empathy and kindness toward other people
  • Being respectful of differing opinions, viewpoints, and experiences
  • Giving and gracefully accepting constructive feedback
  • Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
  • Focusing on what is best not just for us as individuals, but for the overall community

Examples of unacceptable behavior include:

  • The use of sexualized language or imagery, and sexual attention or advances of any kind
  • Trolling, insulting or derogatory comments, and personal or political attacks
  • Public or private harassment
  • Publishing others\u2019 private information, such as a physical or email address, without their explicit permission
  • Other conduct which could reasonably be considered inappropriate in a professional setting
"},{"location":"code_of_conduct/#enforcement-responsibilities","title":"Enforcement Responsibilities","text":"

Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.

Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.

"},{"location":"code_of_conduct/#scope","title":"Scope","text":"

This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.

"},{"location":"code_of_conduct/#enforcement","title":"Enforcement","text":"

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting jaxley developer Michael Deistler via email (michael.deistler@uni-tuebingen.de). All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the reporter of any incident.

"},{"location":"code_of_conduct/#enforcement-guidelines","title":"Enforcement Guidelines","text":"

Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:

"},{"location":"code_of_conduct/#1-correction","title":"1. Correction","text":"

Community Impact: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.

Consequence: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.

"},{"location":"code_of_conduct/#2-warning","title":"2. Warning","text":"

Community Impact: A violation through a single incident or series of actions.

Consequence: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.

"},{"location":"code_of_conduct/#3-temporary-ban","title":"3. Temporary Ban","text":"

Community Impact: A serious violation of community standards, including sustained inappropriate behavior.

Consequence: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.

"},{"location":"code_of_conduct/#4-permanent-ban","title":"4. Permanent Ban","text":"

Community Impact: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.

Consequence: A permanent ban from any sort of public interaction within the community.

"},{"location":"code_of_conduct/#attribution","title":"Attribution","text":"

This Code of Conduct is adapted from the Contributor Covenant, version 2.1, available at https://www.contributor-covenant.org/version/2/1/code_of_conduct.html.

Community Impact Guidelines were inspired by Mozilla\u2019s code of conduct enforcement ladder.

For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.

"},{"location":"contribute/","title":"Guide","text":""},{"location":"contribute/#user-experiences-bugs-and-feature-requests","title":"User experiences, bugs, and feature requests","text":"

To report bugs and suggest features (including better documentation), please head over to issues on GitHub.

"},{"location":"contribute/#code-contributions","title":"Code contributions","text":"

In general, we use pull requests to make changes to Jaxley. So, if you are planning to make a contribution, please fork, create a feature branch and then make a PR from your feature branch to the upstream Jaxley (details).

"},{"location":"contribute/#development-environment","title":"Development environment","text":"

Clone the repo and install via setup.py using pip install -e \".[dev]\" (the dev flag installs development and testing dependencies).

"},{"location":"contribute/#style-conventions","title":"Style conventions","text":"

For docstrings and comments, we use Google Style.

Code needs to pass through the following tools, which are installed alongside Jaxley:

black: Automatic code formatting for Python. You can run black manually from the console using black . in the top directory of the repository, which will format all files.

isort: Used to consistently order imports. You can run isort manually from the console using isort in the top directory.

black and isort are checked as part of our CI actions. If these checks fail please make sure you have installed the latest versions for each of them and run them locally.

"},{"location":"contribute/#online-documentation","title":"Online documentation","text":"

Most of the documentation is written in markdown (basic markdown guide).

You can directly fix mistakes and suggest clearer formulations in markdown files simply by initiating a PR on through GitHub. Click on documentation file and look for the little pencil at top right.

"},{"location":"credits/","title":"Credits","text":"

Jaxley is a collaborative project between the groups of Jakob Macke (Uni T\u00fcbingen), Pedro Gon\u00e7alves (KU Leuven / NERF), and Philipp Berens (Uni T\u00fcbingen).

"},{"location":"credits/#license","title":"License","text":"

Jaxley is licensed under the Apache License Version 2.0 (Apache-2.0) and

Copyright (C) 2024 Michael Deistler, Jakob H. Macke, Pedro J. Goncalves, Philipp Berens.

"},{"location":"credits/#important-dependencies-and-prior-art","title":"Important dependencies and prior art","text":"
  • We greatly benefited from previous toolboxes for simulating multicompartment neurons, in particular NEURON.
"},{"location":"credits/#funding","title":"Funding","text":"

This work was supported by the German Research Foundation (DFG) through Germany\u2019s Excellence Strategy (EXC 2064 \u2013 Project number 390727645) and the CRC 1233 \u201cRobust Vision\u201d, the German Federal Ministry of Education and Research (Tu\u0308bingen AI Center, FKZ: 01IS18039A), the \u2018Certification and Foundations of Safe Machine Learning Systems in Healthcare\u2019 project funded by the Carl Zeiss Foundation, and the European Union (ERC, \u201cDeepCoMechTome\u201d, ref. 101089288, \u201cNextMechMod\u201d, ref. 101039115).

"},{"location":"faq/","title":"Frequently asked questions","text":"
  • What kinds of models can be implemented in Jaxley?
  • What units does Jaxley use?
  • How can I save and load cells and networks?

See also the discussion page and the issue tracker on the Jaxley GitHub repository for recent questions and problems.

"},{"location":"install/","title":"Installation","text":""},{"location":"install/#install-the-most-recent-stable-version","title":"Install the most recent stable version","text":"

Jaxley is available on PyPI:

pip install jaxley\n
This will install Jaxley with CPU support. If you want GPU support, follow the instructions on the JAX github repository to install JAX with GPU support (in addition to installing Jaxley). For example, for NVIDIA GPUs, run
pip install -U \"jax[cuda12]\"\n

"},{"location":"install/#install-from-source","title":"Install from source","text":"

You can also install Jaxley from source:

git clone https://github.com/jaxleyverse/jaxley.git\ncd jaxley\npip install -e .\n

Note that pip>=21.3 is required to install the editable version with pyproject.toml see pip docs.

"},{"location":"faq/question_01/","title":"What units does Jaxley use?","text":"

Jaxley uses the same units as the NEURON simulator, which are listed here.

"},{"location":"faq/question_02/","title":"How can I save and load cells and networks?","text":"

All modules (i.e., compartments, branches, cells, and networks) in Jaxley can be saved and loaded with pickle:

import jaxley as jx\nimport pickle\n\n# ... define network, cell, etc.\nnetwork = jx.Network([cell1, cell2])\n\n# Save.\nwith open(\"path/to/file.pkl\", \"wb\") as handle:\n    pickle.dump(network, handle)\n\n# Load.\nwith open(\"path/to/file.pkl\", \"rb\") as handle:\n    network = pickle.load(handle)\n

"},{"location":"faq/question_03/","title":"What kinds of models can be implemented in Jaxley?","text":"

Jaxley focuses on biophysical, Hodgkin-Huxley-type models. You can think of Jaxley like the NEURON simulator written in JAX.

Jaxley allows to simulate the following types of models, as well as networks thereof:

  • single-compartment (point neuron) Hodgkin-Huxley models
  • multi-compartment Hodgkin-Huxley models
  • rate-based neuron models

For all of these models, Jaxley is flexible and accurate. For example, it can flexibly add new channel models, use different kinds of synapses (conductance-based, tanh, \u2026), and it can insert different kinds of channels in different branches (or compartments) within single cells. Like NEURON, Jaxley implements a backward-Euler solver for stable numerical solution of multi-compartment neurons.

However, Jaxley does not implement the following types of models:

  • leaky-integrate and fire neurons
  • Ishikevich neuron models
  • etc\u2026
"},{"location":"reference/connect/","title":"Connecting Cells","text":""},{"location":"reference/connect/#jaxley.connect.connect","title":"connect(pre, post, synapse_type)","text":"

Connect two compartments with a chemical synapse.

The pre- and postsynaptic compartments must be different compartments of the same network.

Parameters:

Name Type Description Default pre View

View of the presynaptic compartment.

required post View

View of the postsynaptic compartment.

required synapse_type Synapse

The synapse to append

required Source code in jaxley/connect.py
def connect(\n    pre: \"View\",\n    post: \"View\",\n    synapse_type: \"Synapse\",\n):\n    \"\"\"Connect two compartments with a chemical synapse.\n\n    The pre- and postsynaptic compartments must be different compartments of the\n    same network.\n\n    Args:\n        pre: View of the presynaptic compartment.\n        post: View of the postsynaptic compartment.\n        synapse_type: The synapse to append\n    \"\"\"\n    assert is_same_network(\n        pre, post\n    ), \"Pre and post compartments must be part of the same network.\"\n\n    pre.base._append_multiple_synapses(pre.nodes, post.nodes, synapse_type)\n
"},{"location":"reference/connect/#jaxley.connect.connectivity_matrix_connect","title":"connectivity_matrix_connect(pre_cell_view, post_cell_view, synapse_type, connectivity_matrix)","text":"

Appends multiple connections which build a custom connected network.

Connects pre- and postsynaptic cells according to a custom connectivity matrix. Entries > 0 in the matrix indicate a connection between the corresponding cells. Connections are from branch 0 location 0 to a randomly chosen branch and loc.

Parameters:

Name Type Description Default pre_cell_view View

View of the presynaptic cell.

required post_cell_view View

View of the postsynaptic cell.

required synapse_type Synapse

The synapse to append.

required connectivity_matrix ndarray[bool]

A boolean matrix indicating the connections between cells.

required Source code in jaxley/connect.py
def connectivity_matrix_connect(\n    pre_cell_view: \"View\",\n    post_cell_view: \"View\",\n    synapse_type: \"Synapse\",\n    connectivity_matrix: np.ndarray[bool],\n):\n    \"\"\"Appends multiple connections which build a custom connected network.\n\n    Connects pre- and postsynaptic cells according to a custom connectivity matrix.\n    Entries > 0 in the matrix indicate a connection between the corresponding cells.\n    Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n    Args:\n        pre_cell_view: View of the presynaptic cell.\n        post_cell_view: View of the postsynaptic cell.\n        synapse_type: The synapse to append.\n        connectivity_matrix: A boolean matrix indicating the connections between cells.\n    \"\"\"\n    # Get pre- and postsynaptic cell indices.\n    pre_cell_inds = pre_cell_view._cells_in_view\n    post_cell_inds = post_cell_view._cells_in_view\n    # setting scope ensure that this works indep of current scope\n    pre_nodes = pre_cell_view.scope(\"local\").branch(0).comp(0).nodes\n    pre_nodes[\"index\"] = pre_nodes.index\n    pre_cell_nodes = pre_nodes.set_index(\"global_cell_index\")\n\n    assert connectivity_matrix.shape == (\n        len(pre_cell_inds),\n        len(post_cell_inds),\n    ), \"Connectivity matrix must have shape (num_pre, num_post).\"\n    assert connectivity_matrix.dtype == bool, \"Connectivity matrix must be boolean.\"\n\n    # get connection pairs from connectivity matrix\n    from_idx, to_idx = np.where(connectivity_matrix)\n    pre_cell_inds = pre_cell_inds[from_idx]\n    post_cell_inds = post_cell_inds[to_idx]\n\n    # Sample random postsynaptic compartments (global comp indices).\n    global_post_indices = np.hstack(\n        [\n            sample_comp(post_cell_view.scope(\"global\").cell(cell_idx))\n            for cell_idx in post_cell_inds\n        ]\n    )\n    post_rows = post_cell_view.nodes.loc[global_post_indices]\n\n    # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n    global_pre_indices = pre_cell_nodes.loc[pre_cell_inds, \"index\"].to_numpy()\n    pre_rows = pre_cell_view.select(nodes=global_pre_indices).nodes\n\n    pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
"},{"location":"reference/connect/#jaxley.connect.fully_connect","title":"fully_connect(pre_cell_view, post_cell_view, synapse_type)","text":"

Appends multiple connections which build a fully connected layer.

Connections are from branch 0 location 0 to a randomly chosen branch and loc.

Parameters:

Name Type Description Default pre_cell_view View

View of the presynaptic cell.

required post_cell_view View

View of the postsynaptic cell.

required synapse_type Synapse

The synapse to append.

required Source code in jaxley/connect.py
def fully_connect(\n    pre_cell_view: \"View\",\n    post_cell_view: \"View\",\n    synapse_type: \"Synapse\",\n):\n    \"\"\"Appends multiple connections which build a fully connected layer.\n\n    Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n    Args:\n        pre_cell_view: View of the presynaptic cell.\n        post_cell_view: View of the postsynaptic cell.\n        synapse_type: The synapse to append.\n    \"\"\"\n    # Get pre- and postsynaptic cell indices.\n    num_pre = len(pre_cell_view._cells_in_view)\n    num_post = len(post_cell_view._cells_in_view)\n\n    # Infer indices of (random) postsynaptic compartments.\n    global_post_indices = (\n        post_cell_view.nodes.groupby(\"global_cell_index\")\n        .sample(num_pre, replace=True)\n        .index.to_numpy()\n    )\n    global_post_indices = global_post_indices.reshape((-1, num_pre), order=\"F\").ravel()\n    post_rows = post_cell_view.nodes.loc[global_post_indices]\n\n    # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n    pre_rows = pre_cell_view.scope(\"local\").branch(0).comp(0).nodes.copy()\n    # Repeat rows `num_post` times. See SO 50788508.\n    pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)].reset_index(drop=True)\n\n    pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
"},{"location":"reference/connect/#jaxley.connect.is_same_network","title":"is_same_network(pre, post)","text":"

Check if views are from the same network.

Source code in jaxley/connect.py
def is_same_network(pre: \"View\", post: \"View\") -> bool:\n    \"\"\"Check if views are from the same network.\"\"\"\n    is_in_net = \"network\" in pre.base.__class__.__name__.lower()\n    is_in_same_net = pre.base is post.base\n    return is_in_net and is_in_same_net\n
"},{"location":"reference/connect/#jaxley.connect.sample_comp","title":"sample_comp(cell_view, num=1, replace=True)","text":"

Sample a compartment from a cell.

Returns View with shape (num, num_cols).

Source code in jaxley/connect.py
def sample_comp(cell_view: \"View\", num: int = 1, replace=True) -> \"CompartmentView\":\n    \"\"\"Sample a compartment from a cell.\n\n    Returns View with shape (num, num_cols).\"\"\"\n    return np.random.choice(cell_view._comps_in_view, num, replace=replace)\n
"},{"location":"reference/connect/#jaxley.connect.sparse_connect","title":"sparse_connect(pre_cell_view, post_cell_view, synapse_type, p)","text":"

Appends multiple connections which build a sparse, randomly connected layer.

Connections are from branch 0 location 0 to a randomly chosen branch and loc.

Parameters:

Name Type Description Default pre_cell_view View

View of the presynaptic cell.

required post_cell_view View

View of the postsynaptic cell.

required synapse_type Synapse

The synapse to append.

required p float

Probability of connection.

required Source code in jaxley/connect.py
def sparse_connect(\n    pre_cell_view: \"View\",\n    post_cell_view: \"View\",\n    synapse_type: \"Synapse\",\n    p: float,\n):\n    \"\"\"Appends multiple connections which build a sparse, randomly connected layer.\n\n    Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n    Args:\n        pre_cell_view: View of the presynaptic cell.\n        post_cell_view: View of the postsynaptic cell.\n        synapse_type: The synapse to append.\n        p: Probability of connection.\n    \"\"\"\n    # Get pre- and postsynaptic cell indices.\n    pre_cell_inds = pre_cell_view._cells_in_view\n    post_cell_inds = post_cell_view._cells_in_view\n    num_pre = len(pre_cell_inds)\n    num_post = len(post_cell_inds)\n\n    num_connections = np.random.binomial(num_pre * num_post, p)\n    pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections)\n    post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections)\n\n    # Sort the synapses only for convenience of inspecting `.edges`.\n    sorting = np.argsort(pre_syn_neurons)\n    pre_syn_neurons = pre_syn_neurons[sorting]\n    post_syn_neurons = post_syn_neurons[sorting]\n\n    # Post-synapse is a randomly chosen branch and compartment.\n    global_post_indices = [\n        sample_comp(post_cell_view.scope(\"global\").cell(cell_idx))\n        for cell_idx in post_syn_neurons\n    ]\n    global_post_indices = (\n        np.hstack(global_post_indices) if len(global_post_indices) > 1 else []\n    )\n    post_rows = post_cell_view.base.nodes.loc[global_post_indices]\n\n    # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n    global_pre_indices = pre_cell_view.base._cumsum_ncomp_per_cell[pre_syn_neurons]\n    pre_rows = pre_cell_view.base.nodes.loc[global_pre_indices]\n\n    if len(pre_rows) > 0:\n        pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
"},{"location":"reference/integration/","title":"Simulation","text":""},{"location":"reference/integration/#jaxley.integrate.add_clamps","title":"add_clamps(externals, external_inds, data_clamps=None)","text":"

Adds clamps to the external inputs.

Parameters:

Name Type Description Default externals Dict

Current external inputs.

required external_inds Dict

Current external indices.

required data_clamps Optional[Tuple[str, ndarray, DataFrame]]

Additional data clamps. Defaults to None.

None

Returns:

Type Description Tuple[Dict, Dict]

Tuple[Dict, Dict]: Updated external inputs and indices.

Source code in jaxley/integrate.py
def add_clamps(\n    externals: Dict,\n    external_inds: Dict,\n    data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None,\n) -> Tuple[Dict, Dict]:\n    \"\"\"Adds clamps to the external inputs.\n\n    Args:\n        externals (Dict): Current external inputs.\n        external_inds (Dict): Current external indices.\n        data_clamps (Optional[Tuple[str, jnp.ndarray, pd.DataFrame]], optional): Additional data clamps. Defaults to None.\n\n    Returns:\n        Tuple[Dict, Dict]: Updated external inputs and indices.\n    \"\"\"\n    # If a clamp is inserted, add it to the external inputs.\n    if data_clamps is not None:\n        state_name, clamps, inds = data_clamps\n        if state_name in externals.keys():\n            externals[state_name] = jnp.concatenate([externals[state_name], clamps])\n            external_inds[state_name] = jnp.concatenate(\n                [external_inds[state_name], inds.index.to_numpy()]\n            )\n        else:\n            externals[state_name] = clamps\n            external_inds[state_name] = inds.index.to_numpy()\n\n    return externals, external_inds\n
"},{"location":"reference/integration/#jaxley.integrate.add_stimuli","title":"add_stimuli(externals, external_inds, data_stimuli=None)","text":"

Extends the external inputs with the stimuli.

Parameters:

Name Type Description Default externals Dict

Current external inputs.

required external_inds Dict

Current external indices.

required data_stimuli Optional[Tuple[ndarray, DataFrame]]

Additional data stimuli. Defaults to None.

None

Returns:

Type Description Tuple[Dict, Dict]

Tuple[Dict, Dict]: Updated external inputs and indices.

Source code in jaxley/integrate.py
def add_stimuli(\n    externals: Dict,\n    external_inds: Dict,\n    data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n) -> Tuple[Dict, Dict]:\n    \"\"\"Extends the external inputs with the stimuli.\n\n    Args:\n        externals (Dict): Current external inputs.\n        external_inds (Dict): Current external indices.\n        data_stimuli (Optional[Tuple[jnp.ndarray, pd.DataFrame]], optional): Additional data stimuli. Defaults to None.\n\n    Returns:\n        Tuple[Dict, Dict]: Updated external inputs and indices.\n    \"\"\"\n    # If stimulus is inserted, add it to the external inputs.\n    if \"i\" in externals.keys() or data_stimuli is not None:\n        if \"i\" in externals.keys():\n            if data_stimuli is not None:\n                externals[\"i\"] = jnp.concatenate([externals[\"i\"], data_stimuli[1]])\n                external_inds[\"i\"] = jnp.concatenate(\n                    [external_inds[\"i\"], data_stimuli[2].index.to_numpy()]\n                )\n        else:\n            externals[\"i\"] = data_stimuli[1]\n            external_inds[\"i\"] = data_stimuli[2].index.to_numpy()\n\n    return externals, external_inds\n
"},{"location":"reference/integration/#jaxley.integrate.build_init_and_step_fn","title":"build_init_and_step_fn(module, voltage_solver='jaxley.stone', solver='bwd_euler')","text":"

This function returns the init_fn and step_fn which initialize the parameters and states of the neuron model and then step through the model

Parameters:

Name Type Description Default module Module

A Module object that e.g. a cell.

required voltage_solver str

Voltage solver used in step. Defaults to \u201cjaxley.stone\u201d.

'jaxley.stone' solver str

ODE solver. Defaults to \u201cbwd_euler\u201d.

'bwd_euler'

Returns:

Type Description Tuple[Callable, Callable]

init_fn, step_fn: Functions that initialize the state and parameters, and perform a single integration step, respectively.

Source code in jaxley/integrate.py
def build_init_and_step_fn(\n    module: Module,\n    voltage_solver: str = \"jaxley.stone\",\n    solver: str = \"bwd_euler\",\n) -> Tuple[Callable, Callable]:\n    \"\"\"This function returns the `init_fn` and `step_fn` which initialize the\n    parameters and states of the neuron model and then step through the model\n\n    Args:\n        module (Module): A `Module` object that e.g. a cell.\n        voltage_solver (str, optional): Voltage solver used in step. Defaults to \"jaxley.stone\".\n        solver (str, optional): ODE solver. Defaults to \"bwd_euler\".\n\n    Returns:\n        init_fn, step_fn: Functions that initialize the state and parameters, and perform\n            a single integration step, respectively.\n    \"\"\"\n    # Initialize the external inputs and their indices.\n    external_inds = module.external_inds.copy()\n\n    def init_fn(\n        params: List[Dict[str, jnp.ndarray]],\n        all_states: Optional[Dict] = None,\n        param_state: Optional[List[Dict]] = None,\n        delta_t: float = 0.025,\n    ) -> Tuple[Dict, Dict]:\n        \"\"\"Initializes the parameters and states of the neuron model.\n\n        Args:\n            params (List[Dict[str, jnp.ndarray]]): List of trainable parameters.\n            all_states (Optional[Dict], optional): State if alread initialized. Defaults to None.\n            param_state (Optional[List[Dict]], optional): Parameters returned by `data_set`.. Defaults to None.\n            delta_t (float, optional): Step size. Defaults to 0.025.\n\n        Returns:\n            Tuple[Dict, Dict]: All states and parameters.\n        \"\"\"\n        # Make the `trainable_params` of the same shape as the `param_state`, such that\n        # they can be processed together by `get_all_parameters`.\n        pstate = params_to_pstate(params, module.indices_set_by_trainables)\n        if param_state is not None:\n            pstate += param_state\n\n        all_params = module.get_all_parameters(pstate, voltage_solver=voltage_solver)\n        all_states = (\n            module.get_all_states(pstate, all_params, delta_t)\n            if all_states is None\n            else all_states\n        )\n        return all_states, all_params\n\n    def step_fn(\n        all_states: Dict,\n        all_params: Dict,\n        externals: Dict,\n        external_inds: Dict = external_inds,\n        delta_t: float = 0.025,\n    ) -> Dict:\n        \"\"\"Performs a single integration step with step size delta_t.\n\n        Args:\n            all_states (Dict): Current state of the neuron model.\n            all_params (Dict): Current parameters of the neuron model.\n            externals (Dict): External inputs.\n            external_inds (Dict, optional): External indices. Defaults to `module.external_inds`.\n            delta_t (float, optional): Time step. Defaults to 0.025.\n\n        Returns:\n            Dict: Updated states.\n        \"\"\"\n        state = all_states\n        state = module.step(\n            state,\n            delta_t,\n            external_inds,\n            externals,\n            params=all_params,\n            solver=solver,\n            voltage_solver=voltage_solver,\n        )\n        return state\n\n    return init_fn, step_fn\n
"},{"location":"reference/integration/#jaxley.integrate.integrate","title":"integrate(module, params=[], *, param_state=None, data_stimuli=None, data_clamps=None, t_max=None, delta_t=0.025, solver='bwd_euler', voltage_solver='jaxley.stone', checkpoint_lengths=None, all_states=None, return_states=False)","text":"

Solves ODE and simulates neuron model.

Parameters:

Name Type Description Default params List[Dict[str, ndarray]]

Trainable parameters returned by get_parameters().

[] param_state Optional[List[Dict]]

Parameters returned by data_set.

None data_stimuli Optional[Tuple[ndarray, DataFrame]]

Outputs of .data_stimulate(), only needed if stimuli change across function calls.

None data_clamps Optional[Tuple[str, ndarray, DataFrame]]

Outputs of .data_clamp(), only needed if clamps change across function calls.

None t_max Optional[float]

Duration of the simulation in milliseconds. If t_max is greater than the length of the stimulus input, the stimulus will be padded at the end with zeros. If t_max is smaller, then the stimulus with be truncated.

None delta_t float

Time step of the solver in milliseconds.

0.025 solver str

Which ODE solver to use. Either of [\u201cfwd_euler\u201d, \u201cbwd_euler\u201d, \u201ccrank_nicolson\u201d].

'bwd_euler' tridiag_solver

Algorithm to solve tridiagonal systems. The different options only affect bwd_euler and crank_nicolson solvers. Either of [\u201cstone\u201d, \u201cthomas\u201d], where stone is much faster on GPU for long branches with many compartments and thomas is slightly faster on CPU (thomas is used in NEURON).

required checkpoint_lengths Optional[List[int]]

Number of timesteps at every level of checkpointing. The prod(checkpoint_lengths) must be larger or equal to the desired number of simulated timesteps. Warning: the simulation is run for prod(checkpoint_lengths) timesteps, and the result is posthoc truncated to the desired simulation length. Therefore, a poor choice of checkpoint_lengths can lead to longer simulation time. If None, no checkpointing is applied.

None all_states Optional[Dict]

An optional initial state that was returned by a previous jx.integrate(..., return_states=True) run. Overrides potentially trainable initial states.

None return_states bool

If True, it returns all states such that the current state of the Module can be set with set_states.

False Source code in jaxley/integrate.py
def integrate(\n    module: Module,\n    params: List[Dict[str, jnp.ndarray]] = [],\n    *,\n    param_state: Optional[List[Dict]] = None,\n    data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n    data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None,\n    t_max: Optional[float] = None,\n    delta_t: float = 0.025,\n    solver: str = \"bwd_euler\",\n    voltage_solver: str = \"jaxley.stone\",\n    checkpoint_lengths: Optional[List[int]] = None,\n    all_states: Optional[Dict] = None,\n    return_states: bool = False,\n) -> jnp.ndarray:\n    \"\"\"\n    Solves ODE and simulates neuron model.\n\n    Args:\n        params: Trainable parameters returned by `get_parameters()`.\n        param_state: Parameters returned by `data_set`.\n        data_stimuli: Outputs of `.data_stimulate()`, only needed if stimuli change\n            across function calls.\n        data_clamps: Outputs of `.data_clamp()`, only needed if clamps change across\n            function calls.\n        t_max: Duration of the simulation in milliseconds. If `t_max` is greater than\n            the length of the stimulus input, the stimulus will be padded at the end\n            with zeros. If `t_max` is smaller, then the stimulus with be truncated.\n        delta_t: Time step of the solver in milliseconds.\n        solver: Which ODE solver to use. Either of [\"fwd_euler\", \"bwd_euler\",\n            \"crank_nicolson\"].\n        tridiag_solver: Algorithm to solve tridiagonal systems. The  different options\n            only affect `bwd_euler` and `crank_nicolson` solvers. Either of [\"stone\",\n            \"thomas\"], where `stone` is much faster on GPU for long branches\n            with many compartments and `thomas` is slightly faster on CPU (`thomas` is\n            used in NEURON).\n        checkpoint_lengths: Number of timesteps at every level of checkpointing. The\n            `prod(checkpoint_lengths)` must be larger or equal to the desired number of\n            simulated timesteps. Warning: the simulation is run for\n            `prod(checkpoint_lengths)` timesteps, and the result is posthoc truncated\n            to the desired simulation length. Therefore, a poor choice of\n            `checkpoint_lengths` can lead to longer simulation time. If `None`, no\n            checkpointing is applied.\n        all_states: An optional initial state that was returned by a previous\n            `jx.integrate(..., return_states=True)` run. Overrides potentially\n            trainable initial states.\n        return_states: If True, it returns all states such that the current state of\n            the `Module` can be set with `set_states`.\n    \"\"\"\n\n    assert module.initialized, \"Module is not initialized, run `._initialize()`.\"\n    module.to_jax()  # Creates `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.\n\n    # Initialize the external inputs and their indices.\n    externals = module.externals.copy()\n    external_inds = module.external_inds.copy()\n\n    # If stimulus is inserted, add it to the external inputs.\n    externals, external_inds = add_stimuli(externals, external_inds, data_stimuli)\n\n    # If a clamp is inserted, add it to the external inputs.\n    externals, external_inds = add_clamps(externals, external_inds, data_clamps)\n\n    if not externals.keys():\n        # No stimulus was inserted and no clamp was set.\n        assert (\n            t_max is not None\n        ), \"If no stimulus or clamp are inserted you have to specify the simulation duration at `jx.integrate(..., t_max=)`.\"\n\n    for key in externals.keys():\n        externals[key] = externals[key].T  # Shape `(time, num_stimuli)`.\n\n    if module.recordings.empty:\n        raise ValueError(\"No recordings are set. Please set them.\")\n    rec_inds = module.recordings.rec_index.to_numpy()\n    rec_states = module.recordings.state.to_numpy()\n\n    # Shorten or pad stimulus depending on `t_max`.\n    if t_max is not None:\n        t_max_steps = int(t_max // delta_t + 1)\n\n        # Pad or truncate the stimulus.\n        for key in externals.keys():\n            if t_max_steps > externals[key].shape[0]:\n                if key == \"i\":\n                    pad = jnp.zeros(\n                        (t_max_steps - externals[\"i\"].shape[0], externals[\"i\"].shape[1])\n                    )\n                    externals[\"i\"] = jnp.concatenate((externals[\"i\"], pad))\n                else:\n                    raise NotImplementedError(\n                        \"clamp must be at least as long as simulation.\"\n                    )\n            else:\n                externals[key] = externals[key][:t_max_steps, :]\n\n    init_fn, step_fn = build_init_and_step_fn(\n        module, voltage_solver=voltage_solver, solver=solver\n    )\n    all_states, all_params = init_fn(params, all_states, param_state, delta_t)\n\n    def _body_fun(state, externals):\n        state = step_fn(state, all_params, externals, external_inds, delta_t)\n        recs = jnp.asarray(\n            [\n                state[rec_state][rec_ind]\n                for rec_state, rec_ind in zip(rec_states, rec_inds)\n            ]\n        )\n        return state, recs\n\n    # If necessary, pad the stimulus with zeros in order to simulate sufficiently long.\n    # The total simulation length will be `prod(checkpoint_lengths)`. At the end, we\n    # return only the first `nsteps_to_return` elements (plus the initial state).\n    if externals:\n        example_key = list(externals.keys())[0]\n        nsteps_to_return = len(externals[example_key])\n    else:\n        nsteps_to_return = t_max_steps\n\n    if checkpoint_lengths is None:\n        checkpoint_lengths = [nsteps_to_return]\n        length = nsteps_to_return\n    else:\n        length = prod(checkpoint_lengths)\n        size_difference = length - nsteps_to_return\n        assert (\n            nsteps_to_return <= length\n        ), \"The desired simulation duration is longer than `prod(nested_length)`.\"\n        if externals:\n            dummy_external = jnp.zeros(\n                (size_difference, externals[example_key].shape[1])\n            )\n            for key in externals.keys():\n                externals[key] = jnp.concatenate([externals[key], dummy_external])\n\n    # Record the initial state.\n    init_recs = jnp.asarray(\n        [\n            all_states[rec_state][rec_ind]\n            for rec_state, rec_ind in zip(rec_states, rec_inds)\n        ]\n    )\n    init_recording = jnp.expand_dims(init_recs, axis=0)\n\n    # Run simulation.\n    all_states, recordings = nested_checkpoint_scan(\n        _body_fun,\n        all_states,\n        externals,\n        length=length,\n        nested_lengths=checkpoint_lengths,\n    )\n    recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T\n    return (recs, all_states) if return_states else recs\n
"},{"location":"reference/integration/#jaxley.solver_gate.exponential_euler","title":"exponential_euler(x, dt, x_inf, x_tau)","text":"

An exact solver for the linear dynamical system dx = -(x - x_inf) / x_tau.

Source code in jaxley/solver_gate.py
def exponential_euler(\n    x: jnp.ndarray,\n    dt: float,\n    x_inf: jnp.ndarray,\n    x_tau: jnp.ndarray,\n):\n    \"\"\"An exact solver for the linear dynamical system `dx = -(x - x_inf) / x_tau`.\"\"\"\n    exp_term = save_exp(-dt / x_tau)\n    return x * exp_term + x_inf * (1.0 - exp_term)\n
"},{"location":"reference/integration/#jaxley.solver_gate.save_exp","title":"save_exp(x, max_value=20.0)","text":"

Clip the input to a maximum value and return its exponential.

Source code in jaxley/solver_gate.py
def save_exp(x, max_value: float = 20.0):\n    \"\"\"Clip the input to a maximum value and return its exponential.\"\"\"\n    x = jnp.clip(x, a_max=max_value)\n    return jnp.exp(x)\n
"},{"location":"reference/integration/#jaxley.solver_gate.solve_inf_gate_exponential","title":"solve_inf_gate_exponential(x, dt, s_inf, tau_s)","text":"

solves dx/dt = (s_inf - x) / tau_s via exponential Euler

Parameters:

Name Type Description Default x ndarray

gate variable

required dt float

time_delta

required s_inf ndarray

description

required tau_s ndarray

description

required

Returns:

Name Type Description _type_

updated gate

Source code in jaxley/solver_gate.py
def solve_inf_gate_exponential(\n    x: jnp.ndarray,\n    dt: float,\n    s_inf: jnp.ndarray,\n    tau_s: jnp.ndarray,\n):\n    \"\"\"solves dx/dt = (s_inf - x) / tau_s\n    via exponential Euler\n\n    Args:\n        x (jnp.ndarray): gate variable\n        dt (float): time_delta\n        s_inf (jnp.ndarray): _description_\n        tau_s (jnp.ndarray): _description_\n\n    Returns:\n        _type_: updated gate\n    \"\"\"\n    slope = -1.0 / tau_s\n    exp_term = save_exp(slope * dt)\n    return x * exp_term + s_inf * (1.0 - exp_term)\n
"},{"location":"reference/integration/#jaxley.solver_voltage.step_voltage_explicit","title":"step_voltage_explicit(voltages, voltage_terms, constant_terms, axial_conductances, internal_node_inds, sinks, sources, types, ncomp_per_branch, par_inds, child_inds, nbranches, solver, delta_t, idx, debug_states)","text":"

Solve one timestep of branched nerve equations with explicit (forward) Euler.

Source code in jaxley/solver_voltage.py
def step_voltage_explicit(\n    voltages: jnp.ndarray,\n    voltage_terms: jnp.ndarray,\n    constant_terms: jnp.ndarray,\n    axial_conductances: jnp.ndarray,\n    internal_node_inds: jnp.ndarray,\n    sinks: jnp.ndarray,\n    sources: jnp.ndarray,\n    types: jnp.ndarray,\n    ncomp_per_branch: jnp.ndarray,\n    par_inds: jnp.ndarray,\n    child_inds: jnp.ndarray,\n    nbranches: int,\n    solver: str,\n    delta_t: float,\n    idx: JaxleySolveIndexer,\n    debug_states,\n) -> jnp.ndarray:\n    \"\"\"Solve one timestep of branched nerve equations with explicit (forward) Euler.\"\"\"\n    voltages = jnp.reshape(voltages, (nbranches, -1))\n    voltage_terms = jnp.reshape(voltage_terms, (nbranches, -1))\n    constant_terms = jnp.reshape(constant_terms, (nbranches, -1))\n\n    update = _voltage_vectorfield(\n        voltages,\n        voltage_terms,\n        constant_terms,\n        types,\n        sources,\n        sinks,\n        axial_conductances,\n        par_inds,\n        child_inds,\n        nbranches,\n        solver,\n        delta_t,\n        idx,\n        debug_states,\n    )\n    new_voltates = voltages + delta_t * update\n    return new_voltates.ravel(order=\"C\")\n
"},{"location":"reference/integration/#jaxley.solver_voltage.step_voltage_implicit_with_jaxley_spsolve","title":"step_voltage_implicit_with_jaxley_spsolve(voltages, voltage_terms, constant_terms, axial_conductances, internal_node_inds, sinks, sources, types, ncomp_per_branch, par_inds, child_inds, nbranches, solver, delta_t, idx, debug_states)","text":"

Solve one timestep of branched nerve equations with implicit (backward) Euler.

Source code in jaxley/solver_voltage.py
def step_voltage_implicit_with_jaxley_spsolve(\n    voltages: jnp.ndarray,\n    voltage_terms: jnp.ndarray,\n    constant_terms: jnp.ndarray,\n    axial_conductances: jnp.ndarray,\n    internal_node_inds: jnp.ndarray,\n    sinks: jnp.ndarray,\n    sources: jnp.ndarray,\n    types: jnp.ndarray,\n    ncomp_per_branch: jnp.ndarray,\n    par_inds: jnp.ndarray,\n    child_inds: jnp.ndarray,\n    nbranches: int,\n    solver: str,\n    delta_t: float,\n    idx: JaxleySolveIndexer,\n    debug_states,\n):\n    \"\"\"Solve one timestep of branched nerve equations with implicit (backward) Euler.\"\"\"\n    # Build diagonals.\n    c2c = np.isin(types, [0, 1, 2])\n    total_ncomp = idx.cumsum_ncomp[-1]\n    diags = jnp.ones(total_ncomp)\n\n    # if-case needed because `.at` does not allow empty inputs, but the input is\n    # empty for compartments.\n    if len(sinks[c2c]) > 0:\n        diags = diags.at[idx.mask(sinks[c2c])].add(delta_t * axial_conductances[c2c])\n\n    diags = diags.at[idx.mask(internal_node_inds)].add(delta_t * voltage_terms)\n\n    # Build solves.\n    solves = jnp.zeros(total_ncomp)\n    solves = solves.at[idx.mask(internal_node_inds)].add(\n        voltages + delta_t * constant_terms\n    )\n\n    # Build upper and lower within the branch.\n    c2c = types == 0  # c2c = compartment-to-compartment.\n\n    # Build uppers.\n    uppers = jnp.zeros(total_ncomp)\n    upper_inds = sources[c2c] > sinks[c2c]\n    sinks_upper = sinks[c2c][upper_inds]\n    if len(sinks_upper) > 0:\n        uppers = uppers.at[idx.mask(sinks_upper)].add(\n            -delta_t * axial_conductances[c2c][upper_inds]\n        )\n\n    # Build lowers.\n    lowers = jnp.zeros(total_ncomp)\n    lower_inds = sources[c2c] < sinks[c2c]\n    sinks_lower = sinks[c2c][lower_inds]\n    if len(sinks_lower) > 0:\n        lowers = lowers.at[idx.mask(sinks_lower)].add(\n            -delta_t * axial_conductances[c2c][lower_inds]\n        )\n\n    # Build branchpoint conductances.\n    branchpoint_conds_parents = axial_conductances[types == 1]\n    branchpoint_conds_children = axial_conductances[types == 2]\n    branchpoint_weights_parents = axial_conductances[types == 3]\n    branchpoint_weights_children = axial_conductances[types == 4]\n    all_branchpoint_vals = jnp.concatenate(\n        [branchpoint_weights_parents, branchpoint_weights_children]\n    )\n    # Find unique group identifiers\n    num_branchpoints = len(branchpoint_conds_parents)\n    branchpoint_diags = -group_and_sum(\n        all_branchpoint_vals, idx.branchpoint_group_inds, num_branchpoints\n    )\n    branchpoint_solves = jnp.zeros((num_branchpoints,))\n\n    branchpoint_conds_children = -delta_t * branchpoint_conds_children\n    branchpoint_conds_parents = -delta_t * branchpoint_conds_parents\n\n    # Here, I move all child and parent indices towards a branchpoint into a larger\n    # vector. This is wasteful, but it makes indexing much easier. JIT compiling\n    # makes the speed difference negligible.\n    # Children.\n    bp_conds_children = jnp.zeros(nbranches)\n    bp_weights_children = jnp.zeros(nbranches)\n    # Parents.\n    bp_conds_parents = jnp.zeros(nbranches)\n    bp_weights_parents = jnp.zeros(nbranches)\n\n    # `.at[inds]` requires that `inds` is not empty, so we need an if-case here.\n    # `len(inds) == 0` is the case for branches and compartments.\n    if num_branchpoints > 0:\n        bp_conds_children = bp_conds_children.at[child_inds].set(\n            branchpoint_conds_children\n        )\n        bp_weights_children = bp_weights_children.at[child_inds].set(\n            branchpoint_weights_children\n        )\n        bp_conds_parents = bp_conds_parents.at[par_inds].set(branchpoint_conds_parents)\n        bp_weights_parents = bp_weights_parents.at[par_inds].set(\n            branchpoint_weights_parents\n        )\n\n    # Triangulate the linear system of equations.\n    (\n        diags,\n        lowers,\n        solves,\n        uppers,\n        branchpoint_diags,\n        branchpoint_solves,\n        bp_weights_children,\n        bp_conds_parents,\n    ) = _triang_branched(\n        lowers,\n        diags,\n        uppers,\n        solves,\n        bp_conds_children,\n        bp_conds_parents,\n        bp_weights_children,\n        bp_weights_parents,\n        branchpoint_diags,\n        branchpoint_solves,\n        solver,\n        ncomp_per_branch,\n        idx,\n        debug_states,\n    )\n\n    # Backsubstitute the linear system of equations.\n    (\n        solves,\n        lowers,\n        diags,\n        bp_weights_parents,\n        branchpoint_solves,\n        bp_conds_children,\n    ) = _backsub_branched(\n        lowers,\n        diags,\n        uppers,\n        solves,\n        bp_conds_children,\n        bp_conds_parents,\n        bp_weights_children,\n        bp_weights_parents,\n        branchpoint_diags,\n        branchpoint_solves,\n        solver,\n        ncomp_per_branch,\n        idx,\n        debug_states,\n    )\n    return solves.ravel(order=\"C\")[idx.mask(internal_node_inds)]\n
"},{"location":"reference/mechanisms/","title":"Channels","text":""},{"location":"reference/mechanisms/#channel","title":"Channel","text":"

Channel base class. All channels inherit from this class.

As in NEURON, a Channel is considered a distributed process, which means that its conductances are to be specified in S/cm2 and its currents are to be specified in uA/cm2.

Source code in jaxley/channels/channel.py
class Channel:\n    \"\"\"Channel base class. All channels inherit from this class.\n\n    As in NEURON, a `Channel` is considered a distributed process, which means that its\n    conductances are to be specified in `S/cm2` and its currents are to be specified in\n    `uA/cm2`.\"\"\"\n\n    _name = None\n    channel_params = None\n    channel_states = None\n    current_name = None\n\n    def __init__(self, name: Optional[str] = None):\n        contact = (\n            \"If you have any questions, please reach out via email to \"\n            \"michael.deistler@uni-tuebingen.de or create an issue on Github: \"\n            \"https://github.com/jaxleyverse/jaxley/issues. Thank you!\"\n        )\n        if (\n            not hasattr(self, \"current_is_in_mA_per_cm2\")\n            or not self.current_is_in_mA_per_cm2\n        ):\n            raise ValueError(\n                \"The channel you are using is deprecated. \"\n                \"In Jaxley version 0.5.0, we changed the unit of the current returned \"\n                \"by `compute_current` of channels from `uA/cm^2` to `mA/cm^2`. Please \"\n                \"update your channel model (by dividing the resulting current by 1000) \"\n                \"and set `self.current_is_in_mA_per_cm2=True` as the first line \"\n                f\"in the `__init__()` method of your channel. {contact}\"\n            )\n\n        self._name = name if name else self.__class__.__name__\n\n    @property\n    def name(self) -> Optional[str]:\n        \"\"\"The name of the channel (by default, this is the class name).\"\"\"\n        return self._name\n\n    def change_name(self, new_name: str):\n        \"\"\"Change the channel name.\n\n        Args:\n            new_name: The new name of the channel.\n\n        Returns:\n            Renamed channel, such that this function is chainable.\n        \"\"\"\n        old_prefix = self._name + \"_\"\n        new_prefix = new_name + \"_\"\n\n        self._name = new_name\n        self.channel_params = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.channel_params.items()\n        }\n\n        self.channel_states = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.channel_states.items()\n        }\n        return self\n\n    def update_states(\n        self, states, dt, v, params\n    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"Return the updated states.\"\"\"\n        raise NotImplementedError\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Given channel states and voltage, return the current through the channel.\n\n        Args:\n            states: All states of the compartment.\n            v: Voltage of the compartment in mV.\n            params: Parameters of the channel (conductances in `S/cm2`).\n\n        Returns:\n            Current in `uA/cm2`.\n        \"\"\"\n        raise NotImplementedError\n\n    def init_state(\n        self,\n        states: Dict[str, jnp.ndarray],\n        v: jnp.ndarray,\n        params: Dict[str, jnp.ndarray],\n        delta_t: float,\n    ):\n        \"\"\"Initialize states of channel.\"\"\"\n        return {}\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.name","title":"name: Optional[str] property","text":"

The name of the channel (by default, this is the class name).

"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.change_name","title":"change_name(new_name)","text":"

Change the channel name.

Parameters:

Name Type Description Default new_name str

The new name of the channel.

required

Returns:

Type Description

Renamed channel, such that this function is chainable.

Source code in jaxley/channels/channel.py
def change_name(self, new_name: str):\n    \"\"\"Change the channel name.\n\n    Args:\n        new_name: The new name of the channel.\n\n    Returns:\n        Renamed channel, such that this function is chainable.\n    \"\"\"\n    old_prefix = self._name + \"_\"\n    new_prefix = new_name + \"_\"\n\n    self._name = new_name\n    self.channel_params = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.channel_params.items()\n    }\n\n    self.channel_states = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.channel_states.items()\n    }\n    return self\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.compute_current","title":"compute_current(states, v, params)","text":"

Given channel states and voltage, return the current through the channel.

Parameters:

Name Type Description Default states Dict[str, ndarray]

All states of the compartment.

required v

Voltage of the compartment in mV.

required params Dict[str, ndarray]

Parameters of the channel (conductances in S/cm2).

required

Returns:

Type Description

Current in uA/cm2.

Source code in jaxley/channels/channel.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Given channel states and voltage, return the current through the channel.\n\n    Args:\n        states: All states of the compartment.\n        v: Voltage of the compartment in mV.\n        params: Parameters of the channel (conductances in `S/cm2`).\n\n    Returns:\n        Current in `uA/cm2`.\n    \"\"\"\n    raise NotImplementedError\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.init_state","title":"init_state(states, v, params, delta_t)","text":"

Initialize states of channel.

Source code in jaxley/channels/channel.py
def init_state(\n    self,\n    states: Dict[str, jnp.ndarray],\n    v: jnp.ndarray,\n    params: Dict[str, jnp.ndarray],\n    delta_t: float,\n):\n    \"\"\"Initialize states of channel.\"\"\"\n    return {}\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.update_states","title":"update_states(states, dt, v, params)","text":"

Return the updated states.

Source code in jaxley/channels/channel.py
def update_states(\n    self, states, dt, v, params\n) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:\n    \"\"\"Return the updated states.\"\"\"\n    raise NotImplementedError\n
"},{"location":"reference/mechanisms/#hh","title":"HH","text":"

Bases: Channel

Hodgkin-Huxley channel.

Source code in jaxley/channels/hh.py
class HH(Channel):\n    \"\"\"Hodgkin-Huxley channel.\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        self.current_is_in_mA_per_cm2 = True\n\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gNa\": 0.12,\n            f\"{prefix}_gK\": 0.036,\n            f\"{prefix}_gLeak\": 0.0003,\n            f\"{prefix}_eNa\": 50.0,\n            f\"{prefix}_eK\": -77.0,\n            f\"{prefix}_eLeak\": -54.3,\n        }\n        self.channel_states = {\n            f\"{prefix}_m\": 0.2,\n            f\"{prefix}_h\": 0.2,\n            f\"{prefix}_n\": 0.2,\n        }\n        self.current_name = f\"i_HH\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Return updated HH channel state.\"\"\"\n        prefix = self._name\n        m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n        new_m = solve_gate_exponential(m, dt, *self.m_gate(v))\n        new_h = solve_gate_exponential(h, dt, *self.h_gate(v))\n        new_n = solve_gate_exponential(n, dt, *self.n_gate(v))\n        return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h, f\"{prefix}_n\": new_n}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current through HH channels.\"\"\"\n        prefix = self._name\n        m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n\n        gNa = params[f\"{prefix}_gNa\"] * (m**3) * h  # S/cm^2\n        gK = params[f\"{prefix}_gK\"] * n**4  # S/cm^2\n        gLeak = params[f\"{prefix}_gLeak\"]  # S/cm^2\n\n        return (\n            gNa * (v - params[f\"{prefix}_eNa\"])\n            + gK * (v - params[f\"{prefix}_eK\"])\n            + gLeak * (v - params[f\"{prefix}_eLeak\"])\n        )\n\n    def init_state(self, states, v, params, delta_t):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_m, beta_m = self.m_gate(v)\n        alpha_h, beta_h = self.h_gate(v)\n        alpha_n, beta_n = self.n_gate(v)\n        return {\n            f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n            f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n            f\"{prefix}_n\": alpha_n / (alpha_n + beta_n),\n        }\n\n    @staticmethod\n    def m_gate(v):\n        alpha = 0.1 * _vtrap(-(v + 40), 10)\n        beta = 4.0 * save_exp(-(v + 65) / 18)\n        return alpha, beta\n\n    @staticmethod\n    def h_gate(v):\n        alpha = 0.07 * save_exp(-(v + 65) / 20)\n        beta = 1.0 / (save_exp(-(v + 35) / 10) + 1)\n        return alpha, beta\n\n    @staticmethod\n    def n_gate(v):\n        alpha = 0.01 * _vtrap(-(v + 55), 10)\n        beta = 0.125 * save_exp(-(v + 65) / 80)\n        return alpha, beta\n
"},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.compute_current","title":"compute_current(states, v, params)","text":"

Return current through HH channels.

Source code in jaxley/channels/hh.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current through HH channels.\"\"\"\n    prefix = self._name\n    m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n\n    gNa = params[f\"{prefix}_gNa\"] * (m**3) * h  # S/cm^2\n    gK = params[f\"{prefix}_gK\"] * n**4  # S/cm^2\n    gLeak = params[f\"{prefix}_gLeak\"]  # S/cm^2\n\n    return (\n        gNa * (v - params[f\"{prefix}_eNa\"])\n        + gK * (v - params[f\"{prefix}_eK\"])\n        + gLeak * (v - params[f\"{prefix}_eLeak\"])\n    )\n
"},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.init_state","title":"init_state(states, v, params, delta_t)","text":"

Initialize the state such at fixed point of gate dynamics.

Source code in jaxley/channels/hh.py
def init_state(self, states, v, params, delta_t):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_m, beta_m = self.m_gate(v)\n    alpha_h, beta_h = self.h_gate(v)\n    alpha_n, beta_n = self.n_gate(v)\n    return {\n        f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n        f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n        f\"{prefix}_n\": alpha_n / (alpha_n + beta_n),\n    }\n
"},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.update_states","title":"update_states(states, dt, v, params)","text":"

Return updated HH channel state.

Source code in jaxley/channels/hh.py
def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Return updated HH channel state.\"\"\"\n    prefix = self._name\n    m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n    new_m = solve_gate_exponential(m, dt, *self.m_gate(v))\n    new_h = solve_gate_exponential(h, dt, *self.h_gate(v))\n    new_n = solve_gate_exponential(n, dt, *self.n_gate(v))\n    return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h, f\"{prefix}_n\": new_n}\n
"},{"location":"reference/mechanisms/#pospischil","title":"Pospischil","text":"

Bases: Channel

Leak current

Source code in jaxley/channels/pospischil.py
class Leak(Channel):\n    \"\"\"Leak current\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        self.current_is_in_mA_per_cm2 = True\n\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gLeak\": 1e-4,\n            f\"{prefix}_eLeak\": -70.0,\n        }\n        self.channel_states = {}\n        self.current_name = f\"i_{prefix}\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"No state to update.\"\"\"\n        return {}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        gLeak = params[f\"{prefix}_gLeak\"]  # S/cm^2\n        return gLeak * (v - params[f\"{prefix}_eLeak\"])\n\n    def init_state(self, states, v, params, delta_t):\n        return {}\n

Bases: Channel

Sodium channel

Source code in jaxley/channels/pospischil.py
class Na(Channel):\n    \"\"\"Sodium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        self.current_is_in_mA_per_cm2 = True\n\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gNa\": 50e-3,\n            \"eNa\": 50.0,\n            \"vt\": -60.0,  # Global parameter, not prefixed with `Na`.\n        }\n        self.channel_states = {f\"{prefix}_m\": 0.2, f\"{prefix}_h\": 0.2}\n        self.current_name = f\"i_Na\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n        new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params[\"vt\"]))\n        new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params[\"vt\"]))\n        return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n\n        gNa = params[f\"{prefix}_gNa\"] * (m**3) * h  # S/cm^2\n\n        current = gNa * (v - params[\"eNa\"])\n        return current\n\n    def init_state(self, states, v, params, delta_t):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_m, beta_m = self.m_gate(v, params[\"vt\"])\n        alpha_h, beta_h = self.h_gate(v, params[\"vt\"])\n        return {\n            f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n            f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n        }\n\n    @staticmethod\n    def m_gate(v, vt):\n        v_alpha = v - vt - 13.0\n        alpha = 0.32 * efun(-0.25 * v_alpha) / 0.25\n\n        v_beta = v - vt - 40.0\n        beta = 0.28 * efun(0.2 * v_beta) / 0.2\n        return alpha, beta\n\n    @staticmethod\n    def h_gate(v, vt):\n        v_alpha = v - vt - 17.0\n        alpha = 0.128 * save_exp(-v_alpha / 18.0)\n\n        v_beta = v - vt - 40.0\n        beta = 4.0 / (save_exp(-v_beta / 5.0) + 1.0)\n        return alpha, beta\n

Bases: Channel

Potassium channel

Source code in jaxley/channels/pospischil.py
class K(Channel):\n    \"\"\"Potassium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        self.current_is_in_mA_per_cm2 = True\n\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gK\": 5e-3,\n            \"eK\": -90.0,\n            \"vt\": -60.0,  # Global parameter, not prefixed with `Na`.\n        }\n        self.channel_states = {f\"{prefix}_n\": 0.2}\n        self.current_name = f\"i_K\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        n = states[f\"{prefix}_n\"]\n        new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params[\"vt\"]))\n        return {f\"{prefix}_n\": new_n}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        n = states[f\"{prefix}_n\"]\n\n        gK = params[f\"{prefix}_gK\"] * (n**4)  # S/cm^2\n\n        return gK * (v - params[\"eK\"])\n\n    def init_state(self, states, v, params, delta_t):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_n, beta_n = self.n_gate(v, params[\"vt\"])\n        return {f\"{prefix}_n\": alpha_n / (alpha_n + beta_n)}\n\n    @staticmethod\n    def n_gate(v, vt):\n        v_alpha = v - vt - 15.0\n        alpha = 0.032 * efun(-0.2 * v_alpha) / 0.2\n\n        v_beta = v - vt - 10.0\n        beta = 0.5 * save_exp(-v_beta / 40.0)\n        return alpha, beta\n

Bases: Channel

Slow M Potassium channel

Source code in jaxley/channels/pospischil.py
class Km(Channel):\n    \"\"\"Slow M Potassium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        self.current_is_in_mA_per_cm2 = True\n\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gKm\": 0.004e-3,\n            f\"{prefix}_taumax\": 4000.0,\n            f\"eK\": -90.0,\n        }\n        self.channel_states = {f\"{prefix}_p\": 0.2}\n        self.current_name = f\"i_K\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        p = states[f\"{prefix}_p\"]\n        new_p = solve_inf_gate_exponential(\n            p, dt, *self.p_gate(v, params[f\"{prefix}_taumax\"])\n        )\n        return {f\"{prefix}_p\": new_p}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        p = states[f\"{prefix}_p\"]\n\n        gKm = params[f\"{prefix}_gKm\"] * p  # S/cm^2\n        return gKm * (v - params[\"eK\"])\n\n    def init_state(self, states, v, params, delta_t):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_p, beta_p = self.p_gate(v, params[f\"{prefix}_taumax\"])\n        return {f\"{prefix}_p\": alpha_p / (alpha_p + beta_p)}\n\n    @staticmethod\n    def p_gate(v, taumax):\n        v_p = v + 35.0\n        p_inf = 1.0 / (1.0 + save_exp(-0.1 * v_p))\n\n        tau_p = taumax / (3.3 * save_exp(0.05 * v_p) + save_exp(-0.05 * v_p))\n\n        return p_inf, tau_p\n

Bases: Channel

L-type Calcium channel

Source code in jaxley/channels/pospischil.py
class CaL(Channel):\n    \"\"\"L-type Calcium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        self.current_is_in_mA_per_cm2 = True\n\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gCaL\": 0.1e-3,\n            \"eCa\": 120.0,\n        }\n        self.channel_states = {f\"{prefix}_q\": 0.2, f\"{prefix}_r\": 0.2}\n        self.current_name = f\"i_Ca\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n        new_q = solve_gate_exponential(q, dt, *self.q_gate(v))\n        new_r = solve_gate_exponential(r, dt, *self.r_gate(v))\n        return {f\"{prefix}_q\": new_q, f\"{prefix}_r\": new_r}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n        gCaL = params[f\"{prefix}_gCaL\"] * (q**2) * r  # S/cm^2\n\n        return gCaL * (v - params[\"eCa\"])\n\n    def init_state(self, states, v, params, delta_t):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_q, beta_q = self.q_gate(v)\n        alpha_r, beta_r = self.r_gate(v)\n        return {\n            f\"{prefix}_q\": alpha_q / (alpha_q + beta_q),\n            f\"{prefix}_r\": alpha_r / (alpha_r + beta_r),\n        }\n\n    @staticmethod\n    def q_gate(v):\n        v_alpha = -v - 27.0\n        alpha = 0.055 * efun(v_alpha / 3.8) * 3.8\n\n        v_beta = -v - 75.0\n        beta = 0.94 * save_exp(v_beta / 17.0)\n        return alpha, beta\n\n    @staticmethod\n    def r_gate(v):\n        v_alpha = -v - 13.0\n        alpha = 0.000457 * save_exp(v_alpha / 50)\n\n        v_beta = -v - 15.0\n        beta = 0.0065 / (save_exp(v_beta / 28.0) + 1)\n        return alpha, beta\n

Bases: Channel

T-type Calcium channel

Source code in jaxley/channels/pospischil.py
class CaT(Channel):\n    \"\"\"T-type Calcium channel\"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        self.current_is_in_mA_per_cm2 = True\n\n        super().__init__(name)\n        prefix = self._name\n        self.channel_params = {\n            f\"{prefix}_gCaT\": 0.4e-4,\n            f\"{prefix}_vx\": 2.0,\n            \"eCa\": 120.0,  # Global parameter, not prefixed with `CaT`.\n        }\n        self.channel_states = {f\"{prefix}_u\": 0.2}\n        self.current_name = f\"i_Ca\"\n\n    def update_states(\n        self,\n        states: Dict[str, jnp.ndarray],\n        dt,\n        v,\n        params: Dict[str, jnp.ndarray],\n    ):\n        \"\"\"Update state.\"\"\"\n        prefix = self._name\n        u = states[f\"{prefix}_u\"]\n        new_u = solve_inf_gate_exponential(\n            u, dt, *self.u_gate(v, params[f\"{prefix}_vx\"])\n        )\n        return {f\"{prefix}_u\": new_u}\n\n    def compute_current(\n        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n    ):\n        \"\"\"Return current.\"\"\"\n        prefix = self._name\n        u = states[f\"{prefix}_u\"]\n        s_inf = 1.0 / (1.0 + save_exp(-(v + params[f\"{prefix}_vx\"] + 57.0) / 6.2))\n\n        gCaT = params[f\"{prefix}_gCaT\"] * (s_inf**2) * u  # S/cm^2\n\n        return gCaT * (v - params[\"eCa\"])\n\n    def init_state(self, states, v, params, delta_t):\n        \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n        prefix = self._name\n        alpha_u, beta_u = self.u_gate(v, params[f\"{prefix}_vx\"])\n        return {f\"{prefix}_u\": alpha_u / (alpha_u + beta_u)}\n\n    @staticmethod\n    def u_gate(v, vx):\n        v_u1 = v + vx + 81.0\n        u_inf = 1.0 / (1.0 + save_exp(v_u1 / 4))\n\n        tau_u = (30.8 + (211.4 + save_exp((v + vx + 113.2) / 5.0))) / (\n            3.7 * (1 + save_exp((v + vx + 84.0) / 3.2))\n        )\n\n        return u_inf, tau_u\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Leak.compute_current","title":"compute_current(states, v, params)","text":"

Return current.

Source code in jaxley/channels/pospischil.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    gLeak = params[f\"{prefix}_gLeak\"]  # S/cm^2\n    return gLeak * (v - params[f\"{prefix}_eLeak\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Leak.update_states","title":"update_states(states, dt, v, params)","text":"

No state to update.

Source code in jaxley/channels/pospischil.py
def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"No state to update.\"\"\"\n    return {}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.compute_current","title":"compute_current(states, v, params)","text":"

Return current.

Source code in jaxley/channels/pospischil.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n\n    gNa = params[f\"{prefix}_gNa\"] * (m**3) * h  # S/cm^2\n\n    current = gNa * (v - params[\"eNa\"])\n    return current\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.init_state","title":"init_state(states, v, params, delta_t)","text":"

Initialize the state such at fixed point of gate dynamics.

Source code in jaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_m, beta_m = self.m_gate(v, params[\"vt\"])\n    alpha_h, beta_h = self.h_gate(v, params[\"vt\"])\n    return {\n        f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n        f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n    }\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.update_states","title":"update_states(states, dt, v, params)","text":"

Update state.

Source code in jaxley/channels/pospischil.py
def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n    new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params[\"vt\"]))\n    new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params[\"vt\"]))\n    return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.compute_current","title":"compute_current(states, v, params)","text":"

Return current.

Source code in jaxley/channels/pospischil.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    n = states[f\"{prefix}_n\"]\n\n    gK = params[f\"{prefix}_gK\"] * (n**4)  # S/cm^2\n\n    return gK * (v - params[\"eK\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.init_state","title":"init_state(states, v, params, delta_t)","text":"

Initialize the state such at fixed point of gate dynamics.

Source code in jaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_n, beta_n = self.n_gate(v, params[\"vt\"])\n    return {f\"{prefix}_n\": alpha_n / (alpha_n + beta_n)}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.update_states","title":"update_states(states, dt, v, params)","text":"

Update state.

Source code in jaxley/channels/pospischil.py
def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    n = states[f\"{prefix}_n\"]\n    new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params[\"vt\"]))\n    return {f\"{prefix}_n\": new_n}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.compute_current","title":"compute_current(states, v, params)","text":"

Return current.

Source code in jaxley/channels/pospischil.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    p = states[f\"{prefix}_p\"]\n\n    gKm = params[f\"{prefix}_gKm\"] * p  # S/cm^2\n    return gKm * (v - params[\"eK\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.init_state","title":"init_state(states, v, params, delta_t)","text":"

Initialize the state such at fixed point of gate dynamics.

Source code in jaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_p, beta_p = self.p_gate(v, params[f\"{prefix}_taumax\"])\n    return {f\"{prefix}_p\": alpha_p / (alpha_p + beta_p)}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.update_states","title":"update_states(states, dt, v, params)","text":"

Update state.

Source code in jaxley/channels/pospischil.py
def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    p = states[f\"{prefix}_p\"]\n    new_p = solve_inf_gate_exponential(\n        p, dt, *self.p_gate(v, params[f\"{prefix}_taumax\"])\n    )\n    return {f\"{prefix}_p\": new_p}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.compute_current","title":"compute_current(states, v, params)","text":"

Return current.

Source code in jaxley/channels/pospischil.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n    gCaL = params[f\"{prefix}_gCaL\"] * (q**2) * r  # S/cm^2\n\n    return gCaL * (v - params[\"eCa\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.init_state","title":"init_state(states, v, params, delta_t)","text":"

Initialize the state such at fixed point of gate dynamics.

Source code in jaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_q, beta_q = self.q_gate(v)\n    alpha_r, beta_r = self.r_gate(v)\n    return {\n        f\"{prefix}_q\": alpha_q / (alpha_q + beta_q),\n        f\"{prefix}_r\": alpha_r / (alpha_r + beta_r),\n    }\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.update_states","title":"update_states(states, dt, v, params)","text":"

Update state.

Source code in jaxley/channels/pospischil.py
def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n    new_q = solve_gate_exponential(q, dt, *self.q_gate(v))\n    new_r = solve_gate_exponential(r, dt, *self.r_gate(v))\n    return {f\"{prefix}_q\": new_q, f\"{prefix}_r\": new_r}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.compute_current","title":"compute_current(states, v, params)","text":"

Return current.

Source code in jaxley/channels/pospischil.py
def compute_current(\n    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n    \"\"\"Return current.\"\"\"\n    prefix = self._name\n    u = states[f\"{prefix}_u\"]\n    s_inf = 1.0 / (1.0 + save_exp(-(v + params[f\"{prefix}_vx\"] + 57.0) / 6.2))\n\n    gCaT = params[f\"{prefix}_gCaT\"] * (s_inf**2) * u  # S/cm^2\n\n    return gCaT * (v - params[\"eCa\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.init_state","title":"init_state(states, v, params, delta_t)","text":"

Initialize the state such at fixed point of gate dynamics.

Source code in jaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n    \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n    prefix = self._name\n    alpha_u, beta_u = self.u_gate(v, params[f\"{prefix}_vx\"])\n    return {f\"{prefix}_u\": alpha_u / (alpha_u + beta_u)}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.update_states","title":"update_states(states, dt, v, params)","text":"

Update state.

Source code in jaxley/channels/pospischil.py
def update_states(\n    self,\n    states: Dict[str, jnp.ndarray],\n    dt,\n    v,\n    params: Dict[str, jnp.ndarray],\n):\n    \"\"\"Update state.\"\"\"\n    prefix = self._name\n    u = states[f\"{prefix}_u\"]\n    new_u = solve_inf_gate_exponential(\n        u, dt, *self.u_gate(v, params[f\"{prefix}_vx\"])\n    )\n    return {f\"{prefix}_u\": new_u}\n
"},{"location":"reference/mechanisms/#synapses","title":"Synapses","text":""},{"location":"reference/mechanisms/#synapse","title":"Synapse","text":"

Base class for a synapse.

As in NEURON, a Synapse is considered a point process, which means that its conductances are to be specified in uS and its currents are to be specified in nA.

Source code in jaxley/synapses/synapse.py
class Synapse:\n    \"\"\"Base class for a synapse.\n\n    As in NEURON, a `Synapse` is considered a point process, which means that its\n    conductances are to be specified in `uS` and its currents are to be specified in\n    `nA`.\n    \"\"\"\n\n    _name = None\n    synapse_params = None\n    synapse_states = None\n\n    def __init__(self, name: Optional[str] = None):\n        self._name = name if name else self.__class__.__name__\n\n    @property\n    def name(self) -> Optional[str]:\n        return self._name\n\n    def change_name(self, new_name: str):\n        \"\"\"Change the synapse name.\n\n        Args:\n            new_name: The new name of the channel.\n\n        Returns:\n            Renamed channel, such that this function is chainable.\n        \"\"\"\n        old_prefix = self._name + \"_\"\n        new_prefix = new_name + \"_\"\n\n        self._name = new_name\n        self.synapse_params = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.synapse_params.items()\n        }\n\n        self.synapse_states = {\n            (\n                new_prefix + key[len(old_prefix) :]\n                if key.startswith(old_prefix)\n                else key\n            ): value\n            for key, value in self.synapse_states.items()\n        }\n        return self\n\n    def update_states(\n        states: Dict[str, jnp.ndarray],\n        delta_t: float,\n        pre_voltage: jnp.ndarray,\n        post_voltage: jnp.ndarray,\n        params: Dict[str, jnp.ndarray],\n    ) -> Dict[str, jnp.ndarray]:\n        \"\"\"ODE update step.\n\n        Args:\n            states: States of the synapse.\n            delta_t: Time step in `ms`.\n            pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n            post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n            params: Parameters of the synapse. Conductances in `uS`.\n\n        Returns:\n            Updated states.\"\"\"\n        raise NotImplementedError\n\n    def compute_current(\n        states: Dict[str, jnp.ndarray],\n        pre_voltage: jnp.ndarray,\n        post_voltage: jnp.ndarray,\n        params: Dict[str, jnp.ndarray],\n    ) -> jnp.ndarray:\n        \"\"\"Return current through one synapse in `nA`.\n\n        Internally, we use `jax.vmap` to vectorize this function across many synapses.\n\n        Args:\n            states: States of the synapse.\n            pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n            post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n            params: Parameters of the synapse. Conductances in `uS`.\n\n        Returns:\n            Current through the synapse in `nA`, shape `()`.\n        \"\"\"\n        raise NotImplementedError\n
"},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.change_name","title":"change_name(new_name)","text":"

Change the synapse name.

Parameters:

Name Type Description Default new_name str

The new name of the channel.

required

Returns:

Type Description

Renamed channel, such that this function is chainable.

Source code in jaxley/synapses/synapse.py
def change_name(self, new_name: str):\n    \"\"\"Change the synapse name.\n\n    Args:\n        new_name: The new name of the channel.\n\n    Returns:\n        Renamed channel, such that this function is chainable.\n    \"\"\"\n    old_prefix = self._name + \"_\"\n    new_prefix = new_name + \"_\"\n\n    self._name = new_name\n    self.synapse_params = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.synapse_params.items()\n    }\n\n    self.synapse_states = {\n        (\n            new_prefix + key[len(old_prefix) :]\n            if key.startswith(old_prefix)\n            else key\n        ): value\n        for key, value in self.synapse_states.items()\n    }\n    return self\n
"},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.compute_current","title":"compute_current(states, pre_voltage, post_voltage, params)","text":"

Return current through one synapse in nA.

Internally, we use jax.vmap to vectorize this function across many synapses.

Parameters:

Name Type Description Default states Dict[str, ndarray]

States of the synapse.

required pre_voltage ndarray

Voltage of the presynaptic compartment, shape ().

required post_voltage ndarray

Voltage of the postsynaptic compartment, shape ().

required params Dict[str, ndarray]

Parameters of the synapse. Conductances in uS.

required

Returns:

Type Description ndarray

Current through the synapse in nA, shape ().

Source code in jaxley/synapses/synapse.py
def compute_current(\n    states: Dict[str, jnp.ndarray],\n    pre_voltage: jnp.ndarray,\n    post_voltage: jnp.ndarray,\n    params: Dict[str, jnp.ndarray],\n) -> jnp.ndarray:\n    \"\"\"Return current through one synapse in `nA`.\n\n    Internally, we use `jax.vmap` to vectorize this function across many synapses.\n\n    Args:\n        states: States of the synapse.\n        pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n        post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n        params: Parameters of the synapse. Conductances in `uS`.\n\n    Returns:\n        Current through the synapse in `nA`, shape `()`.\n    \"\"\"\n    raise NotImplementedError\n
"},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)","text":"

ODE update step.

Parameters:

Name Type Description Default states Dict[str, ndarray]

States of the synapse.

required delta_t float

Time step in ms.

required pre_voltage ndarray

Voltage of the presynaptic compartment, shape ().

required post_voltage ndarray

Voltage of the postsynaptic compartment, shape ().

required params Dict[str, ndarray]

Parameters of the synapse. Conductances in uS.

required

Returns:

Type Description Dict[str, ndarray]

Updated states.

Source code in jaxley/synapses/synapse.py
def update_states(\n    states: Dict[str, jnp.ndarray],\n    delta_t: float,\n    pre_voltage: jnp.ndarray,\n    post_voltage: jnp.ndarray,\n    params: Dict[str, jnp.ndarray],\n) -> Dict[str, jnp.ndarray]:\n    \"\"\"ODE update step.\n\n    Args:\n        states: States of the synapse.\n        delta_t: Time step in `ms`.\n        pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n        post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n        params: Parameters of the synapse. Conductances in `uS`.\n\n    Returns:\n        Updated states.\"\"\"\n    raise NotImplementedError\n
"},{"location":"reference/mechanisms/#ionotropic-synapse","title":"Ionotropic Synapse","text":"

Bases: Synapse

Compute synaptic current and update synapse state for a generic ionotropic synapse.

The synapse state \u201cs\u201d is the probability that a postsynaptic receptor channel is open, and this depends on the amount of neurotransmitter released, which is in turn dependent on the presynaptic voltage.

The synaptic parameters are
  • gS: the maximal conductance across the postsynaptic membrane (uS)
  • e_syn: the reversal potential across the postsynaptic membrane (mV)
  • k_minus: the rate constant of neurotransmitter unbinding from the postsynaptic receptor (s^-1)
Details of this implementation can be found in the following book chapter

L. F. Abbott and E. Marder, \u201cModeling Small Networks,\u201d in Methods in Neuronal Modeling, C. Koch and I. Sergev, Eds. Cambridge: MIT Press, 1998.

Source code in jaxley/synapses/ionotropic.py
class IonotropicSynapse(Synapse):\n    \"\"\"\n    Compute synaptic current and update synapse state for a generic ionotropic synapse.\n\n    The synapse state \"s\" is the probability that a postsynaptic receptor channel is\n    open, and this depends on the amount of neurotransmitter released, which is in turn\n    dependent on the presynaptic voltage.\n\n    The synaptic parameters are:\n        - gS: the maximal conductance across the postsynaptic membrane (uS)\n        - e_syn: the reversal potential across the postsynaptic membrane (mV)\n        - k_minus: the rate constant of neurotransmitter unbinding from the postsynaptic\n            receptor (s^-1)\n\n    Details of this implementation can be found in the following book chapter:\n        L. F. Abbott and E. Marder, \"Modeling Small Networks,\" in Methods in Neuronal\n        Modeling, C. Koch and I. Sergev, Eds. Cambridge: MIT Press, 1998.\n\n    \"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.synapse_params = {\n            f\"{prefix}_gS\": 1e-4,\n            f\"{prefix}_e_syn\": 0.0,\n            f\"{prefix}_k_minus\": 0.025,\n        }\n        self.synapse_states = {f\"{prefix}_s\": 0.2}\n\n    def update_states(\n        self,\n        states: Dict,\n        delta_t: float,\n        pre_voltage: float,\n        post_voltage: float,\n        params: Dict,\n    ) -> Dict:\n        \"\"\"Return updated synapse state and current.\"\"\"\n        prefix = self._name\n        v_th = -35.0  # mV\n        delta = 10.0  # mV\n\n        s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta))\n        tau_s = (1.0 - s_inf) / params[f\"{prefix}_k_minus\"]\n\n        slope = -1.0 / tau_s\n        exp_term = save_exp(slope * delta_t)\n        new_s = states[f\"{prefix}_s\"] * exp_term + s_inf * (1.0 - exp_term)\n        return {f\"{prefix}_s\": new_s}\n\n    def compute_current(\n        self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n    ) -> float:\n        prefix = self._name\n        g_syn = params[f\"{prefix}_gS\"] * states[f\"{prefix}_s\"]\n        return g_syn * (post_voltage - params[f\"{prefix}_e_syn\"])\n
"},{"location":"reference/mechanisms/#jaxley.synapses.ionotropic.IonotropicSynapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)","text":"

Return updated synapse state and current.

Source code in jaxley/synapses/ionotropic.py
def update_states(\n    self,\n    states: Dict,\n    delta_t: float,\n    pre_voltage: float,\n    post_voltage: float,\n    params: Dict,\n) -> Dict:\n    \"\"\"Return updated synapse state and current.\"\"\"\n    prefix = self._name\n    v_th = -35.0  # mV\n    delta = 10.0  # mV\n\n    s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta))\n    tau_s = (1.0 - s_inf) / params[f\"{prefix}_k_minus\"]\n\n    slope = -1.0 / tau_s\n    exp_term = save_exp(slope * delta_t)\n    new_s = states[f\"{prefix}_s\"] * exp_term + s_inf * (1.0 - exp_term)\n    return {f\"{prefix}_s\": new_s}\n
"},{"location":"reference/mechanisms/#tanh-rate-synapse","title":"TanH Rate Synapse","text":"

Bases: Synapse

Compute synaptic current for tanh synapse (no state).

Source code in jaxley/synapses/tanh_rate.py
class TanhRateSynapse(Synapse):\n    \"\"\"\n    Compute synaptic current for tanh synapse (no state).\n    \"\"\"\n\n    def __init__(self, name: Optional[str] = None):\n        super().__init__(name)\n        prefix = self._name\n        self.synapse_params = {\n            f\"{prefix}_gS\": 1e-4,\n            f\"{prefix}_x_offset\": -70.0,\n            f\"{prefix}_slope\": 1.0,\n        }\n        self.synapse_states = {}\n\n    def update_states(\n        self,\n        states: Dict,\n        delta_t: float,\n        pre_voltage: float,\n        post_voltage: float,\n        params: Dict,\n    ) -> Dict:\n        \"\"\"Return updated synapse state and current.\"\"\"\n        return {}\n\n    def compute_current(\n        self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n    ) -> float:\n        \"\"\"Return updated synapse state and current.\"\"\"\n        prefix = self._name\n        current = (\n            -1\n            * params[f\"{prefix}_gS\"]\n            * jnp.tanh(\n                (pre_voltage - params[f\"{prefix}_x_offset\"]) * params[f\"{prefix}_slope\"]\n            )\n        )\n        return current\n
"},{"location":"reference/mechanisms/#jaxley.synapses.tanh_rate.TanhRateSynapse.compute_current","title":"compute_current(states, pre_voltage, post_voltage, params)","text":"

Return updated synapse state and current.

Source code in jaxley/synapses/tanh_rate.py
def compute_current(\n    self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n) -> float:\n    \"\"\"Return updated synapse state and current.\"\"\"\n    prefix = self._name\n    current = (\n        -1\n        * params[f\"{prefix}_gS\"]\n        * jnp.tanh(\n            (pre_voltage - params[f\"{prefix}_x_offset\"]) * params[f\"{prefix}_slope\"]\n        )\n    )\n    return current\n
"},{"location":"reference/mechanisms/#jaxley.synapses.tanh_rate.TanhRateSynapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)","text":"

Return updated synapse state and current.

Source code in jaxley/synapses/tanh_rate.py
def update_states(\n    self,\n    states: Dict,\n    delta_t: float,\n    pre_voltage: float,\n    post_voltage: float,\n    params: Dict,\n) -> Dict:\n    \"\"\"Return updated synapse state and current.\"\"\"\n    return {}\n
"},{"location":"reference/modules/","title":"Modules","text":""},{"location":"reference/modules/#module","title":"Module","text":"

Bases: ABC

Module base class.

Modules are everything that can be passed to jx.integrate, i.e. compartments, branches, cells, and networks.

This base class defines the scaffold for all jaxley modules (compartments, branches, cells, networks).

Modules can be traversed and modified using the at, cell, branch, comp, edge, and loc methods. The scope method can be used to toggle between global and local indices. Traversal of Modules will return a View of itself, that has a modified set of attributes, which only consider the part of the Module that is in view.

For developers: The above has consequences for how to operate on Module and which changes take affect where. The following guidelines should be followed (copied from View):

  1. We consider a Module to have everything in view.
  2. Views can display and keep track of how a module is traversed. But(!), do not support making changes or setting variables. This still has to be done in the base Module, i.e. self.base. In order to enssure that these changes only affects whatever is currently in view self._nodes_in_view, or self._edges_in_view among others have to be used. Operating on nodes currently in view can for example be done with self.base.node.loc[self._nodes_in_view].
  3. Every attribute of Module that changes based on what\u2019s in view, i.e. xyzr, needs to modified when View is instantiated. I.e. xyzr of cell.branch(0), should be [self.base.xyzr[0]] This could be achieved via: [self.base.xyzr[b] for b in self._branches_in_view].

For developers: If you want to add a new method to Module, here is an example of how to make methods of Module compatible with View:

.. code-block:: python

# Use data in view to return something.\ndef count_small_branches(self):\n    # no need to use self.base.attr + viewed indices,\n    # since no change is made to the attr in question (nodes)\n    comp_lens = self.nodes[\"length\"]\n    branch_lens = comp_lens.groupby(\"global_branch_index\").sum()\n    return np.sum(branch_lens < 10)\n\n# Change data in view.\ndef change_attr_in_view(self):\n    # changes to attrs have to be made via self.base.attr + viewed indices\n    a = func1(self.base.attr1[self._cells_in_view])\n    b = func2(self.base.attr2[self._edges_in_view])\n    self.base.attr3[self._branches_in_view] = a + b\n
Source code in jaxley/modules/base.py
class Module(ABC):\n    \"\"\"Module base class.\n\n    Modules are everything that can be passed to `jx.integrate`, i.e. compartments,\n    branches, cells, and networks.\n\n    This base class defines the scaffold for all jaxley modules (compartments,\n    branches, cells, networks).\n\n    Modules can be traversed and modified using the `at`, `cell`, `branch`, `comp`,\n    `edge`, and `loc` methods. The `scope` method can be used to toggle between\n    global and local indices. Traversal of Modules will return a `View` of itself,\n    that has a modified set of attributes, which only consider the part of the Module\n    that is in view.\n\n    For developers: The above has consequences for how to operate on `Module` and which\n    changes take affect where. The following guidelines should be followed (copied from\n    `View`):\n\n    1. We consider a Module to have everything in view.\n    2. Views can display and keep track of how a module is traversed. But(!),\n       do not support making changes or setting variables. This still has to be\n       done in the base Module, i.e. `self.base`. In order to enssure that these\n       changes only affects whatever is currently in view `self._nodes_in_view`,\n       or `self._edges_in_view` among others have to be used. Operating on nodes\n       currently in view can for example be done with\n       `self.base.node.loc[self._nodes_in_view]`.\n    3. Every attribute of Module that changes based on what's in view, i.e. `xyzr`,\n       needs to modified when View is instantiated. I.e. `xyzr` of `cell.branch(0)`,\n       should be `[self.base.xyzr[0]]` This could be achieved via:\n       `[self.base.xyzr[b] for b in self._branches_in_view]`.\n\n    For developers: If you want to add a new method to `Module`, here is an example of\n    how to make methods of Module compatible with View:\n\n    .. code-block:: python\n\n        # Use data in view to return something.\n        def count_small_branches(self):\n            # no need to use self.base.attr + viewed indices,\n            # since no change is made to the attr in question (nodes)\n            comp_lens = self.nodes[\"length\"]\n            branch_lens = comp_lens.groupby(\"global_branch_index\").sum()\n            return np.sum(branch_lens < 10)\n\n        # Change data in view.\n        def change_attr_in_view(self):\n            # changes to attrs have to be made via self.base.attr + viewed indices\n            a = func1(self.base.attr1[self._cells_in_view])\n            b = func2(self.base.attr2[self._edges_in_view])\n            self.base.attr3[self._branches_in_view] = a + b\n    \"\"\"\n\n    def __init__(self):\n        self.ncomp: int = None\n        self.total_nbranches: int = 0\n        self.nbranches_per_cell: List[int] = None\n\n        self.groups = {}\n\n        self.nodes: Optional[pd.DataFrame] = None\n        self._scope = \"local\"  # defaults to local scope\n        self._nodes_in_view: np.ndarray = None\n        self._edges_in_view: np.ndarray = None\n\n        self.edges = pd.DataFrame(\n            columns=[\n                \"global_edge_index\",\n                \"pre_global_comp_index\",\n                \"post_global_comp_index\",\n                \"pre_locs\",\n                \"post_locs\",\n                \"type\",\n                \"type_ind\",\n            ]\n        )\n\n        self._cumsum_nbranches: Optional[np.ndarray] = None\n\n        self.comb_parents: jnp.ndarray = jnp.asarray([-1])\n\n        self.initialized_morph: bool = False\n        self.initialized_syns: bool = False\n\n        # List of all types of `jx.Synapse`s.\n        self.synapses: List = []\n        self.synapse_param_names = []\n        self.synapse_state_names = []\n        self.synapse_names = []\n        self.synapse_current_names: List[str] = []\n\n        # List of types of all `jx.Channel`s.\n        self.channels: List[Channel] = []\n        self.membrane_current_names: List[str] = []\n\n        # For trainable parameters.\n        self.indices_set_by_trainables: List[jnp.ndarray] = []\n        self.trainable_params: List[Dict[str, jnp.ndarray]] = []\n        self.allow_make_trainable: bool = True\n        self.num_trainable_params: int = 0\n\n        # For recordings.\n        self.recordings: pd.DataFrame = pd.DataFrame().from_dict({})\n\n        # For stimuli or clamps.\n        # E.g. `self.externals = {\"v\": zeros(1000,2), \"i\": ones(1000, 2)}`\n        # for 1000 timesteps and two compartments.\n        self.externals: Dict[str, jnp.ndarray] = {}\n        # E.g. `self.external)inds = {\"v\": jnp.asarray([0,1]), \"i\": jnp.asarray([2,3])}`\n        self.external_inds: Dict[str, jnp.ndarray] = {}\n\n        # x, y, z coordinates and radius.\n        self.xyzr: List[np.ndarray] = []\n        self._radius_generating_fns = None  # Defined by `.read_swc()`.\n\n        # For debugging the solver. Will be empty by default and only filled if\n        # `self._init_morph_for_debugging` is run.\n        self.debug_states = {}\n\n        # needs to be set at the end\n        self.base: Module = self\n\n    def __repr__(self):\n        return f\"{type(self).__name__} with {len(self.channels)} different channels. Use `.nodes` for details.\"\n\n    def __str__(self):\n        return f\"jx.{type(self).__name__}\"\n\n    def __dir__(self):\n        base_dir = object.__dir__(self)\n        return sorted(base_dir + self.synapse_names + list(self.group_nodes.keys()))\n\n    def __getattr__(self, key):\n        # Ensure that hidden methods such as `__deepcopy__` still work.\n        if key.startswith(\"__\"):\n            return super().__getattribute__(key)\n\n        # intercepts calls to groups\n        if key in self.base.groups:\n            view = (\n                self.select(self.groups[key])\n                if key in self.groups\n                else self.select(None)\n            )\n            view._set_controlled_by_param(key)\n            return view\n\n        # intercepts calls to channels\n        if key in [c._name for c in self.base.channels]:\n            channel_names = [c._name for c in self.channels]\n            inds = self.nodes.index[self.nodes[key]].to_numpy()\n            view = self.select(inds) if key in channel_names else self.select(None)\n            view._set_controlled_by_param(key)\n            return view\n\n        # intercepts calls to synapse types\n        if key in self.base.synapse_names:\n            syn_inds = self.edges[self.edges[\"type\"] == key][\n                \"global_edge_index\"\n            ].to_numpy()\n            orig_scope = self._scope\n            view = (\n                self.scope(\"global\").edge(syn_inds).scope(orig_scope)\n                if key in self.synapse_names\n                else self.select(None)\n            )\n            view._set_controlled_by_param(key)  # overwrites param set by edge\n            # Ensure synapse param sharing works with `edge`\n            # `edge` will be removed as part of #463\n            view.edges[\"local_edge_index\"] = np.arange(len(view.edges))\n            return view\n\n    def _childviews(self) -> List[str]:\n        \"\"\"Returns levels that module can be viewed at.\n\n        I.e. for net -> [cell, branch, comp]. For branch -> [comp]\"\"\"\n        levels = [\"network\", \"cell\", \"branch\", \"comp\"]\n        if self._current_view in levels:\n            children = levels[levels.index(self._current_view) + 1 :]\n            return children\n        return []\n\n    def _has_childview(self, key: str) -> bool:\n        child_views = self._childviews()\n        return key in child_views\n\n    def __getitem__(self, index):\n        \"\"\"Lazy indexing of the module.\"\"\"\n        supported_parents = [\"network\", \"cell\", \"branch\"]  # cannot index into comp\n\n        not_group_view = self._current_view not in self.groups\n        assert (\n            self._current_view in supported_parents or not_group_view\n        ), \"Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof.\"\n        index = index if isinstance(index, tuple) else (index,)\n\n        child_views = self._childviews()\n        assert len(index) <= len(child_views), \"Too many indices.\"\n        view = self\n        for i, child in zip(index, child_views):\n            view = view._at_nodes(child, i)\n        return view\n\n    def _update_local_indices(self) -> pd.DataFrame:\n        \"\"\"Compute local indices from the global indices that are in view.\n        This is recomputed everytime a View is created.\"\"\"\n        rerank = lambda df: df.rank(method=\"dense\").astype(int) - 1\n\n        def reorder_cols(\n            df: pd.DataFrame, cols: List[str], first: bool = True\n        ) -> pd.DataFrame:\n            \"\"\"Move cols to front/back.\n\n            Args:\n                df: DataFrame to reorder.\n                cols: List of columns to place before/after remaining columns.\n                first: If True, cols are placed in front, otherwise at the end.\n\n            Returns:\n                DataFrame with reordered columns.\"\"\"\n            new_cols = [col for col in df.columns if first == (col in cols)]\n            new_cols += [col for col in df.columns if first != (col in cols)]\n            return df[new_cols]\n\n        def reindex_a_by_b(\n            df: pd.DataFrame, a: str, b: Optional[Union[str, List[str]]] = None\n        ) -> pd.DataFrame:\n            \"\"\"Reindex based on a different col or several columns\n            for b=[0,0,1,1,2,2,2] -> a=[0,1,0,1,0,1,2]\"\"\"\n            grouped_df = df.groupby(b) if b is not None else df\n            df.loc[:, a] = rerank(grouped_df[a])\n            return df\n\n        index_names = [\"cell_index\", \"branch_index\", \"comp_index\"]  # order is important\n        global_idx_cols = [f\"global_{name}\" for name in index_names]\n        local_idx_cols = [f\"local_{name}\" for name in index_names]\n        idcs = self.nodes[global_idx_cols]\n\n        # update local indices of nodes\n        idcs = reindex_a_by_b(idcs, global_idx_cols[0])\n        idcs = reindex_a_by_b(idcs, global_idx_cols[1], global_idx_cols[0])\n        idcs = reindex_a_by_b(idcs, global_idx_cols[2], global_idx_cols[:2])\n        idcs.columns = [col.replace(\"global\", \"local\") for col in global_idx_cols]\n        self.nodes[local_idx_cols] = idcs[local_idx_cols].astype(int)\n\n        # move indices to the front of the dataframe; move controlled_by_param to the end\n        # move indices of current scope to the front and the others to the back\n        not_scope = \"global\" if self._scope == \"local\" else \"local\"\n        self.nodes = reorder_cols(\n            self.nodes, [f\"{self._scope}_{name}\" for name in index_names], first=True\n        )\n        self.nodes = reorder_cols(\n            self.nodes, [f\"{not_scope}_{name}\" for name in index_names], first=False\n        )\n\n        self.edges = reorder_cols(self.edges, [\"global_edge_index\"])\n        self.nodes = reorder_cols(self.nodes, [\"controlled_by_param\"], first=False)\n        self.edges = reorder_cols(self.edges, [\"controlled_by_param\"], first=False)\n\n    def _init_view(self):\n        \"\"\"Init attributes critical for View.\n\n        Needs to be called at init of a Module.\"\"\"\n        parent = self.__class__.__name__.lower()\n        self._current_view = \"comp\" if parent == \"compartment\" else parent\n        self._nodes_in_view = self.nodes.index.to_numpy()\n        self._edges_in_view = self.edges.index.to_numpy()\n        self.nodes[\"controlled_by_param\"] = 0\n\n    def _compute_coords_of_comp_centers(self) -> np.ndarray:\n        \"\"\"Compute xyz coordinates of compartment centers.\n\n        Centers are the midpoint between the comparment endpoints on the morphology\n        as defined by xyzr.\n\n        Note: For sake of performance, interpolation is not done for each branch\n        individually, but only once along a concatenated (and padded) array of all branches.\n        This means for ncomps = [2,4] and normalized cum_branch_lens of [[0,1],[0,1]] we would\n        interpolate xyz at the locations comp_ends = [[0,0.5,1], [0,0.25,0.5,0.75,1]],\n        where 0 is the start of the branch and 1 is the end point at the full branch_len.\n        To avoid do this in one go we set comp_ends = [0,0.5,1,2,2.25,2.5,2.75,3], and\n        norm_cum_branch_len = [0,1,2,3] incrememting and also padding them by 1 to\n        avoid overlapping branch_lens i.e. norm_cum_branch_len = [0,1,1,2] for only\n        incrementing.\n        \"\"\"\n        nodes_by_branches = self.nodes.groupby(\"global_branch_index\")\n        ncomps = nodes_by_branches[\"global_comp_index\"].nunique().to_numpy()\n\n        comp_ends = [\n            np.linspace(0, 1, ncomp + 1) + 2 * i for i, ncomp in enumerate(ncomps)\n        ]\n        comp_ends = np.hstack(comp_ends)\n\n        comp_ends = comp_ends.reshape(-1)\n        cum_branch_lens = []\n        for i, xyzr in enumerate(self.xyzr):\n            branch_len = np.sqrt(np.sum(np.diff(xyzr[:, :3], axis=0) ** 2, axis=1))\n            cum_branch_len = np.cumsum(np.concatenate([np.array([0]), branch_len]))\n            max_len = cum_branch_len.max()\n            # add padding like above\n            cum_branch_len = cum_branch_len / (max_len if max_len > 0 else 1) + 2 * i\n            cum_branch_len[np.isnan(cum_branch_len)] = 0\n            cum_branch_lens.append(cum_branch_len)\n        cum_branch_lens = np.hstack(cum_branch_lens)\n        xyz = np.vstack(self.xyzr)[:, :3]\n        xyz = v_interp(comp_ends, cum_branch_lens, xyz).T\n        centers = (xyz[:-1] + xyz[1:]) / 2  # unaware of inter vs intra comp centers\n        cum_ncomps = np.cumsum(ncomps)\n        # this means centers between comps have to be removed here\n        between_comp_inds = (cum_ncomps + np.arange(len(cum_ncomps)))[:-1]\n        centers = np.delete(centers, between_comp_inds, axis=0)\n        return centers\n\n    def compute_compartment_centers(self):\n        \"\"\"Add compartment centers to nodes dataframe\"\"\"\n        centers = self._compute_coords_of_comp_centers()\n        self.base.nodes.loc[self._nodes_in_view, [\"x\", \"y\", \"z\"]] = centers\n\n    def _reformat_index(self, idx: Any, dtype: type = int) -> np.ndarray:\n        \"\"\"Transforms different types of indices into an array.\n\n        Takes slice, list, array, ints, range and None and transforms\n        it into array of indices. If index == \"all\" it returns \"all\"\n        to be handled downstream.\n\n        Args:\n            idx: index that specifies at which locations to view the module.\n            dtype: defaults to int, but can also reformat float for use in `loc`\n\n        Returns:\n            array of indices of shape (N,)\"\"\"\n        if is_str_all(idx):  # also asserts that the only allowed str == \"all\"\n            return idx\n\n        np_dtype = np.int64 if dtype is int else np.float64\n        idx = np.array([], dtype=dtype) if idx is None else idx\n        idx = np.array([idx]) if isinstance(idx, (dtype, np_dtype)) else idx\n        idx = np.array(idx) if isinstance(idx, (list, range, pd.Index)) else idx\n\n        idx = np.arange(len(self.base.nodes))[idx] if isinstance(idx, slice) else idx\n        if idx.dtype == bool:\n            shape = (*self.shape, len(self.edges))\n            which_idx = len(idx) == np.array(shape)\n            assert np.any(which_idx), \"Index not matching num of cells/branches/comps.\"\n            dim = shape[np.where(which_idx)[0][0]]\n            idx = np.arange(dim)[idx]\n        assert isinstance(idx, np.ndarray), \"Invalid type\"\n        assert idx.dtype in [np_dtype, bool], \"Invalid dtype\"\n        return idx.reshape(-1)\n\n    def _set_controlled_by_param(self, key: str):\n        \"\"\"Determines which parameters are shared in `make_trainable`.\n\n        Adds column to nodes/edges dataframes to read of shared params from.\n\n        Args:\n            key: key specifying group / view that is in control of the params.\"\"\"\n        if key in [\"comp\", \"branch\", \"cell\"]:\n            self.nodes[\"controlled_by_param\"] = self.nodes[f\"global_{key}_index\"]\n            self.edges[\"controlled_by_param\"] = 0\n        elif key == \"edge\":\n            self.edges[\"controlled_by_param\"] = np.arange(len(self.edges))\n        elif key == \"filter\":\n            self.nodes[\"controlled_by_param\"] = np.arange(len(self.nodes))\n            self.edges[\"controlled_by_param\"] = np.arange(len(self.edges))\n        else:\n            self.nodes[\"controlled_by_param\"] = 0\n            self.edges[\"controlled_by_param\"] = 0\n        self._current_view = key\n\n    def select(\n        self, nodes: np.ndarray = None, edges: np.ndarray = None, sorted: bool = False\n    ) -> View:\n        \"\"\"Return View of the module filtered by specific node or edges indices.\n\n        Args:\n            nodes: indices of nodes to view. If None, all nodes are viewed.\n            edges: indices of edges to view. If None, all edges are viewed.\n            sorted: if True, nodes and edges are sorted.\n\n        Returns:\n            View for subset of selected nodes and/or edges.\"\"\"\n\n        nodes = self._reformat_index(nodes) if nodes is not None else None\n        nodes = self._nodes_in_view if is_str_all(nodes) else nodes\n        nodes = np.sort(nodes) if sorted else nodes\n\n        edges = self._reformat_index(edges) if edges is not None else None\n        edges = self._edges_in_view if is_str_all(edges) else edges\n        edges = np.sort(edges) if sorted else edges\n\n        view = View(self, nodes, edges)\n        view._set_controlled_by_param(\"filter\")\n        return view\n\n    def set_scope(self, scope: str):\n        \"\"\"Toggle between \"global\" or \"local\" scope.\n\n        Determines if global or local indices are used for viewing the module.\n\n        Args:\n            scope: either \"global\" or \"local\".\"\"\"\n        assert scope in [\"global\", \"local\"], \"Invalid scope.\"\n        self._scope = scope\n\n    def scope(self, scope: str) -> View:\n        \"\"\"Return a View of the module with the specified scope.\n\n        For example `cell.scope(\"global\").branch(2).scope(\"local\").comp(1)`\n        will return the 1st compartment of branch 2.\n\n        Args:\n            scope: either \"global\" or \"local\".\n\n        Returns:\n            View with the specified scope.\"\"\"\n        view = self.view\n        view.set_scope(scope)\n        return view\n\n    def _at_nodes(self, key: str, idx: Any) -> View:\n        \"\"\"Return a View of the module filtering `nodes` by specified key and index.\n\n        Keys can be `cell`, `branch`, `comp` and determine which index is used to filter.\n        \"\"\"\n        base_name = self.base.__class__.__name__\n        assert self.base._has_childview(key), f\"{base_name} does not support {key}.\"\n        idx = self._reformat_index(idx)\n        idx = self.nodes[self._scope + f\"_{key}_index\"] if is_str_all(idx) else idx\n        where = self.nodes[self._scope + f\"_{key}_index\"].isin(idx)\n        inds = self.nodes.index[where].to_numpy()\n\n        view = View(self, nodes=inds)\n        view._set_controlled_by_param(key)\n        return view\n\n    def _at_edges(self, key: str, idx: Any) -> View:\n        \"\"\"Return a View of the module filtering `edges` by specified key and index.\n\n        Keys can be `pre`, `post`, `edge` and determine which index is used to filter.\n        \"\"\"\n        idx = self._reformat_index(idx)\n        idx = self.edges[self._scope + f\"_{key}_index\"] if is_str_all(idx) else idx\n        where = self.edges[self._scope + f\"_{key}_index\"].isin(idx)\n        inds = self.edges.index[where].to_numpy()\n\n        view = View(self, edges=inds)\n        view._set_controlled_by_param(key)\n        return view\n\n    def cell(self, idx: Any) -> View:\n        \"\"\"Return a View of the module at the selected cell(s).\n\n        Args:\n            idx: index of the cell to view.\n\n        Returns:\n            View of the module at the specified cell index.\"\"\"\n        return self._at_nodes(\"cell\", idx)\n\n    def branch(self, idx: Any) -> View:\n        \"\"\"Return a View of the module at the selected branches(s).\n\n        Args:\n            idx: index of the branch to view.\n\n        Returns:\n            View of the module at the specified branch index.\"\"\"\n        return self._at_nodes(\"branch\", idx)\n\n    def comp(self, idx: Any) -> View:\n        \"\"\"Return a View of the module at the selected compartments(s).\n\n        Args:\n            idx: index of the comp to view.\n\n        Returns:\n            View of the module at the specified compartment index.\"\"\"\n        return self._at_nodes(\"comp\", idx)\n\n    def edge(self, idx: Any) -> View:\n        \"\"\"Return a View of the module at the selected synapse edges(s).\n\n        Args:\n            idx: index of the edge to view.\n\n        Returns:\n            View of the module at the specified edge index.\"\"\"\n        return self._at_edges(\"edge\", idx)\n\n    def loc(self, at: Any) -> View:\n        \"\"\"Return a View of the module at the selected branch location(s).\n\n        Args:\n            at: location along the branch.\n\n        Returns:\n            View of the module at the specified branch location.\"\"\"\n        global_comp_idxs = []\n        for i in self._branches_in_view:\n            ncomp = self.base.ncomp_per_branch[i]\n            comp_locs = np.linspace(0, 1, ncomp)\n            at = comp_locs if is_str_all(at) else self._reformat_index(at, dtype=float)\n            comp_edges = np.linspace(0, 1 + 1e-10, ncomp + 1)\n            idx = np.digitize(at, comp_edges) - 1 + self.base.cumsum_ncomp[i]\n            global_comp_idxs.append(idx)\n        global_comp_idxs = np.concatenate(global_comp_idxs)\n        orig_scope = self._scope\n        # global scope needed to select correct comps, for i.e. branches w. ncomp=[1,2]\n        # loc(0.9)  will correspond to different local branches (0 vs 1).\n        view = self.scope(\"global\").comp(global_comp_idxs).scope(orig_scope)\n        view._current_view = \"loc\"\n        return view\n\n    @property\n    def _comps_in_view(self):\n        \"\"\"Lists the global compartment indices which are currently part of the view.\"\"\"\n        # method also exists in View. this copy forgoes need to instantiate a View\n        return self.nodes[\"global_comp_index\"].unique()\n\n    @property\n    def _branches_in_view(self):\n        \"\"\"Lists the global branch indices which are currently part of the view.\"\"\"\n        # method also exists in View. this copy forgoes need to instantiate a View\n        return self.nodes[\"global_branch_index\"].unique()\n\n    @property\n    def _cells_in_view(self):\n        \"\"\"Lists the global cell indices which are currently part of the view.\"\"\"\n        # method also exists in View. this copy forgoes need to instantiate a View\n        return self.nodes[\"global_cell_index\"].unique()\n\n    def _iter_submodules(self, name: str):\n        \"\"\"Iterate over submoduleslevel.\n\n        Used for `cells`, `branches`, `comps`.\"\"\"\n        col = self._scope + f\"_{name}_index\"\n        idxs = self.nodes[col].unique()\n        for idx in idxs:\n            yield self._at_nodes(name, idx)\n\n    @property\n    def cells(self):\n        \"\"\"Iterate over all cells in the module.\n\n        Returns a generator that yields a View of each cell.\"\"\"\n        yield from self._iter_submodules(\"cell\")\n\n    @property\n    def branches(self):\n        \"\"\"Iterate over all branches in the module.\n\n        Returns a generator that yields a View of each branch.\"\"\"\n        yield from self._iter_submodules(\"branch\")\n\n    @property\n    def comps(self):\n        \"\"\"Iterate over all compartments in the module.\n        Can be called on any module, i.e. `net.comps`, `cell.comps` or\n        `branch.comps`. `__iter__` does not allow for this.\n\n        Returns a generator that yields a View of each compartment.\"\"\"\n        yield from self._iter_submodules(\"comp\")\n\n    def __iter__(self):\n        \"\"\"Iterate over parts of the module.\n\n        Internally calls `cells`, `branches`, `comps` at the appropriate level.\n\n        Example:\n\n        .. code-block:: python\n\n            for cell in network:\n                for branch in cell:\n                    for comp in branch:\n                        print(comp.nodes.shape)\n        \"\"\"\n        next_level = self._childviews()[0]\n        yield from self._iter_submodules(next_level)\n\n    @property\n    def shape(self) -> Tuple[int]:\n        \"\"\"Returns the number of submodules contained in a module.\n\n        .. code-block:: python\n\n            network.shape = (num_cells, num_branches, num_compartments)\n            cell.shape = (num_branches, num_compartments)\n            branch.shape = (num_compartments,)\n        \"\"\"\n        cols = [\"global_cell_index\", \"global_branch_index\", \"global_comp_index\"]\n        raw_shape = self.nodes[cols].nunique().to_list()\n\n        # ensure (net.shape -> dim=3, cell.shape -> dim=2, branch.shape -> dim=1, comp.shape -> dim=0)\n        levels = [\"network\", \"cell\", \"branch\", \"comp\"]\n        module = self.base.__class__.__name__.lower()\n        module = \"comp\" if module == \"compartment\" else module\n        shape = tuple(raw_shape[levels.index(module) :])\n        return shape\n\n    def copy(\n        self, reset_index: bool = False, as_module: bool = False\n    ) -> Union[Module, View]:\n        \"\"\"Extract part of a module and return a copy of its View or a new module.\n\n        This can be used to call `jx.integrate` on part of a Module.\n\n        Args:\n            reset_index: if True, the indices of the new module are reset to start from 0.\n            as_module: if True, a new module is returned instead of a View.\n\n        Returns:\n            A part of the module or a copied view of it.\"\"\"\n        view = deepcopy(self)\n        warnings.warn(\"This method is experimental, use at your own risk.\")\n        # TODO FROM #447: add reset_index, i.e. for parents, nodes, edges etc. such that they\n        # start from 0/-1 and are contiguous\n        if as_module:\n            raise NotImplementedError(\"Not yet implemented.\")\n            # initialize a new module with the same attributes\n        return view\n\n    @property\n    def view(self):\n        \"\"\"Return view of the module.\"\"\"\n        return View(self, self._nodes_in_view, self._edges_in_view)\n\n    @property\n    def _module_type(self):\n        \"\"\"Return type of the module (compartment, branch, cell, network) as string.\n\n        This is used to perform asserts for some modules (e.g. network cannot use\n        `set_ncomp`) without having to import the module in `base.py`.\"\"\"\n        return self.__class__.__name__.lower()\n\n    def _append_params_and_states(self, param_dict: Dict, state_dict: Dict):\n        \"\"\"Insert the default params of the module (e.g. radius, length).\n\n        This is run at `__init__()`. It does not deal with channels.\n        \"\"\"\n        for param_name, param_value in param_dict.items():\n            self.base.nodes[param_name] = param_value\n        for state_name, state_value in state_dict.items():\n            self.base.nodes[state_name] = state_value\n\n    def _gather_channels_from_constituents(self, constituents: List):\n        \"\"\"Modify `self.channels` and `self.nodes` with channel info from constituents.\n\n        This is run at `__init__()`. It takes all branches of constituents (e.g.\n        of all branches when the are assembled into a cell) and adds columns to\n        `.nodes` for the relevant channels.\n        \"\"\"\n        for module in constituents:\n            for channel in module.channels:\n                if channel._name not in [c._name for c in self.channels]:\n                    self.base.channels.append(channel)\n                if channel.current_name not in self.membrane_current_names:\n                    self.base.membrane_current_names.append(channel.current_name)\n        # Setting columns of channel names to `False` instead of `NaN`.\n        for channel in self.base.channels:\n            name = channel._name\n            self.base.nodes.loc[self.nodes[name].isna(), name] = False\n\n    @only_allow_module\n    def to_jax(self):\n        # TODO FROM #447: Make this work for View?\n        \"\"\"Move `.nodes` to `.jaxnodes`.\n\n        Before the actual simulation is run (via `jx.integrate`), all parameters of\n        the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for\n        simulation, these parameters have to be moved to be `jnp.ndarrays` such that\n        they can be processed on GPU/TPU and such that the simulation can be\n        differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`.\n        \"\"\"\n        self.base.jaxnodes = {}\n        for key, value in self.base.nodes.to_dict(orient=\"list\").items():\n            inds = jnp.arange(len(value))\n            self.base.jaxnodes[key] = jnp.asarray(value)[inds]\n\n        # `jaxedges` contains only parameters (no indices).\n        # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n        # we allow parameter sharing.\n        self.base.jaxedges = {}\n        edges = self.base.edges.to_dict(orient=\"list\")\n        for i, synapse in enumerate(self.base.synapses):\n            condition = np.asarray(edges[\"type_ind\"]) == i\n            for key in synapse.synapse_params:\n                self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n            for key in synapse.synapse_states:\n                self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n\n    def show(\n        self,\n        param_names: Optional[Union[str, List[str]]] = None,\n        *,\n        indices: bool = True,\n        params: bool = True,\n        states: bool = True,\n        channel_names: Optional[List[str]] = None,\n    ) -> pd.DataFrame:\n        \"\"\"Print detailed information about the Module or a view of it.\n\n        Args:\n            param_names: The names of the parameters to show. If `None`, all parameters\n                are shown.\n            indices: Whether to show the indices of the compartments.\n            params: Whether to show the parameters of the compartments.\n            states: Whether to show the states of the compartments.\n            channel_names: The names of the channels to show. If `None`, all channels are\n                shown.\n\n        Returns:\n            A `pd.DataFrame` with the requested information.\n        \"\"\"\n        nodes = self.nodes.copy()  # prevents this from being edited\n\n        cols = []\n        inds = [\"comp_index\", \"branch_index\", \"cell_index\"]\n        scopes = [\"local\", \"global\"]\n        inds = [f\"{s}_{i}\" for i in inds for s in scopes] if indices else []\n        cols += inds\n        cols += [ch._name for ch in self.channels] if channel_names else []\n        cols += (\n            sum([list(ch.channel_params) for ch in self.channels], []) if params else []\n        )\n        cols += (\n            sum([list(ch.channel_states) for ch in self.channels], []) if states else []\n        )\n\n        if not param_names is None:\n            cols = (\n                inds + [c for c in cols if c in param_names]\n                if params\n                else list(param_names)\n            )\n\n        return nodes[cols]\n\n    @only_allow_module\n    def _init_morph(self):\n        \"\"\"Initialize the morphology such that it can be processed by the solvers.\"\"\"\n        self._init_morph_jaxley_spsolve()\n        self._init_morph_jax_spsolve()\n        self.initialized_morph = True\n\n    @abstractmethod\n    def _init_morph_jax_spsolve(self):\n        \"\"\"Initialize the morphology for the JAX sparse solver.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _init_morph_jaxley_spsolve(self):\n        \"\"\"Initialize the morphology for the custom Jaxley solver.\"\"\"\n        raise NotImplementedError\n\n    def _compute_axial_conductances(self, params: Dict[str, jnp.ndarray]):\n        \"\"\"Given radius, length, r_a, compute the axial coupling conductances.\"\"\"\n        return compute_axial_conductances(self._comp_edges, params)\n\n    def set(self, key: str, val: Union[float, jnp.ndarray]):\n        \"\"\"Set parameter of module (or its view) to a new value.\n\n        Note that this function can not be called within `jax.jit` or `jax.grad`.\n        Instead, it should be used set the parameters of the module **before** the\n        simulation. Use `.data_set()` to set parameters during `jax.jit` or\n        `jax.grad`.\n\n        Args:\n            key: The name of the parameter to set.\n            val: The value to set the parameter to. If it is `jnp.ndarray` then it\n                must be of shape `(len(num_compartments))`.\n        \"\"\"\n        if key in self.nodes.columns:\n            not_nan = ~self.nodes[key].isna().to_numpy()\n            self.base.nodes.loc[self._nodes_in_view[not_nan], key] = val\n        elif key in self.edges.columns:\n            not_nan = ~self.edges[key].isna().to_numpy()\n            self.base.edges.loc[self._edges_in_view[not_nan], key] = val\n        else:\n            raise KeyError(f\"Key '{key}' not found in nodes or edges\")\n\n    def data_set(\n        self,\n        key: str,\n        val: Union[float, jnp.ndarray],\n        param_state: Optional[List[Dict]],\n    ):\n        \"\"\"Set parameter of module (or its view) to a new value within `jit`.\n\n        Args:\n            key: The name of the parameter to set.\n            val: The value to set the parameter to. If it is `jnp.ndarray` then it\n                must be of shape `(len(num_compartments))`.\n            param_state: State of the setted parameters, internally used such that this\n                function does not modify global state.\n        \"\"\"\n        # Note: `data_set` does not support arrays for `val`.\n        is_node_param = key in self.nodes.columns\n        data = self.nodes if is_node_param else self.edges\n        viewed_inds = self._nodes_in_view if is_node_param else self._edges_in_view\n        if key in data.columns:\n            not_nan = ~data[key].isna()\n            added_param_state = [\n                {\n                    \"indices\": np.atleast_2d(viewed_inds[not_nan]),\n                    \"key\": key,\n                    \"val\": jnp.atleast_1d(jnp.asarray(val)),\n                }\n            ]\n            if param_state is not None:\n                param_state += added_param_state\n            else:\n                param_state = added_param_state\n        else:\n            raise KeyError(\"Key not recognized.\")\n        return param_state\n\n    def set_ncomp(\n        self,\n        ncomp: int,\n        min_radius: Optional[float] = None,\n    ):\n        \"\"\"Set the number of compartments with which the branch is discretized.\n\n        Args:\n            ncomp: The number of compartments that the branch should be discretized\n                into.\n            min_radius: Only used if the morphology was read from an SWC file. If passed\n                the radius is capped to be at least this value.\n\n        Raises:\n            - When there are stimuli in any compartment in the module.\n            - When there are recordings in any compartment in the module.\n            - When the channels of the compartments are not the same within the branch\n            that is modified.\n            - When the lengths of the compartments are not the same within the branch\n            that is modified.\n            - Unless the morphology was read from an SWC file, when the radiuses of the\n            compartments are not the same within the branch that is modified.\n        \"\"\"\n        assert len(self.base.externals) == 0, \"No stimuli allowed!\"\n        assert len(self.base.recordings) == 0, \"No recordings allowed!\"\n        assert len(self.base.trainable_params) == 0, \"No trainables allowed!\"\n\n        assert self.base._module_type != \"network\", \"This is not allowed for networks.\"\n        assert not (\n            self.base._module_type == \"cell\"\n            and len(self._branches_in_view) == len(self.base._branches_in_view)\n        ), \"This is not allowed for cells.\"\n\n        # Update all attributes that are affected by compartment structure.\n        view = self.nodes.copy()\n        all_nodes = self.base.nodes\n        start_idx = self.nodes[\"global_comp_index\"].to_numpy()[0]\n        ncomp_per_branch = self.base.ncomp_per_branch\n        channel_names = [c._name for c in self.base.channels]\n        channel_param_names = list(\n            chain(*[c.channel_params for c in self.base.channels])\n        )\n        channel_state_names = list(\n            chain(*[c.channel_states for c in self.base.channels])\n        )\n        radius_generating_fns = self.base._radius_generating_fns\n\n        within_branch_radiuses = view[\"radius\"].to_numpy()\n        compartment_lengths = view[\"length\"].to_numpy()\n        num_previous_ncomp = len(within_branch_radiuses)\n        branch_indices = pd.unique(view[\"global_branch_index\"])\n\n        error_msg = lambda name: (\n            f\"You previously modified the {name} of individual compartments, but \"\n            f\"now you are modifying the number of compartments in this branch. \"\n            f\"This is not allowed. First build the morphology with `set_ncomp()` and \"\n            f\"then modify the radiuses and lengths of compartments.\"\n        )\n\n        if (\n            ~np.all(within_branch_radiuses == within_branch_radiuses[0])\n            and radius_generating_fns is None\n        ):\n            raise ValueError(error_msg(\"radius\"))\n\n        for property_name in [\"length\", \"capacitance\", \"axial_resistivity\"]:\n            compartment_properties = view[property_name].to_numpy()\n            if ~np.all(compartment_properties == compartment_properties[0]):\n                raise ValueError(error_msg(property_name))\n\n        if not (self.nodes[channel_names].var() == 0.0).all():\n            raise ValueError(\n                \"Some channel exists only in some compartments of the branch which you\"\n                \"are trying to modify. This is not allowed. First specify the number\"\n                \"of compartments with `.set_ncomp()` and then insert the channels\"\n                \"accordingly.\"\n            )\n\n        if not (\n            self.nodes[channel_param_names + channel_state_names].var() == 0.0\n        ).all():\n            raise ValueError(\n                \"Some channel has different parameters or states between the \"\n                \"different compartments of the branch which you are trying to modify. \"\n                \"This is not allowed. First specify the number of compartments with \"\n                \"`.set_ncomp()` and then insert the channels accordingly.\"\n            )\n\n        # Add new rows as the average of all rows. Special case for the length is below.\n        average_row = self.nodes.mean(skipna=False)\n        average_row = average_row.to_frame().T\n        view = pd.concat([*[average_row] * ncomp], axis=\"rows\")\n\n        # Set the correct datatype after having performed an average which cast\n        # everything to float.\n        integer_cols = [\"global_cell_index\", \"global_branch_index\", \"global_comp_index\"]\n        view[integer_cols] = view[integer_cols].astype(int)\n\n        # Whether or not a channel exists in a compartment is a boolean.\n        boolean_cols = channel_names\n        view[boolean_cols] = view[boolean_cols].astype(bool)\n\n        # Special treatment for the lengths and radiuses. These are not being set as\n        # the average because we:\n        # 1) Want to maintain the total length of a branch.\n        # 2) Want to use the SWC inferred radius.\n        #\n        # Compute new compartment lengths.\n        comp_lengths = np.sum(compartment_lengths) / ncomp\n        view[\"length\"] = comp_lengths\n\n        # Compute new compartment radiuses.\n        if radius_generating_fns is not None:\n            view[\"radius\"] = build_radiuses_from_xyzr(\n                radius_fns=radius_generating_fns,\n                branch_indices=branch_indices,\n                min_radius=min_radius,\n                ncomp=ncomp,\n            )\n        else:\n            view[\"radius\"] = within_branch_radiuses[0] * np.ones(ncomp)\n\n        # Update `.nodes`.\n        # 1) Delete N rows starting from start_idx\n        number_deleted = num_previous_ncomp\n        all_nodes = all_nodes.drop(index=range(start_idx, start_idx + number_deleted))\n\n        # 2) Insert M new rows at the same location\n        df1 = all_nodes.iloc[:start_idx]  # Rows before the insertion point\n        df2 = all_nodes.iloc[start_idx:]  # Rows after the insertion point\n\n        # 3) Combine the parts: before, new rows, and after\n        all_nodes = pd.concat([df1, view, df2]).reset_index(drop=True)\n\n        # Override `comp_index` to just be a consecutive list.\n        all_nodes[\"global_comp_index\"] = np.arange(len(all_nodes))\n\n        # Update compartment structure arguments.\n        ncomp_per_branch[branch_indices] = ncomp\n        ncomp = int(np.max(ncomp_per_branch))\n        cumsum_ncomp = cumsum_leading_zero(ncomp_per_branch)\n        internal_node_inds = np.arange(cumsum_ncomp[-1])\n\n        self.base.nodes = all_nodes\n        self.base.ncomp_per_branch = ncomp_per_branch\n        self.base.ncomp = ncomp\n        self.base.cumsum_ncomp = cumsum_ncomp\n        self.base._internal_node_inds = internal_node_inds\n\n        # Update the morphology indexing (e.g., `.comp_edges`).\n        self.base._initialize()\n        self.base._init_view()\n        self.base._update_local_indices()\n\n    def make_trainable(\n        self,\n        key: str,\n        init_val: Optional[Union[float, list]] = None,\n        verbose: bool = True,\n    ):\n        \"\"\"Make a parameter trainable.\n\n        If a parameter is made trainable, it will be returned by `get_parameters()`\n        and should then be passed to `jx.integrate(..., params=params)`.\n\n        Args:\n            key: Name of the parameter to make trainable.\n            init_val: Initial value of the parameter. If `float`, the same value is\n                used for every created parameter. If `list`, the length of the list has\n                to match the number of created parameters. If `None`, the current\n                parameter value is used and if parameter sharing is performed that the\n                current parameter value is averaged over all shared parameters.\n            verbose: Whether to print the number of parameters that are added and the\n                total number of parameters.\n        \"\"\"\n        assert (\n            self.allow_make_trainable\n        ), \"network.cell('all').make_trainable() is not supported. Use a for-loop over cells.\"\n        ncomps_per_branch = (\n            self.base.nodes[\"global_branch_index\"].value_counts().to_numpy()\n        )\n        assert np.all(\n            ncomps_per_branch == ncomps_per_branch[0]\n        ), \"Parameter sharing is not allowed for modules containing branches with different numbers of compartments.\"\n\n        data = self.nodes if key in self.nodes.columns else None\n        data = self.edges if key in self.edges.columns else data\n\n        assert data is not None, f\"Key '{key}' not found in nodes or edges\"\n        not_nan = ~data[key].isna()\n        data = data.loc[not_nan]\n        assert (\n            len(data) > 0\n        ), \"No settable parameters found in the selected compartments.\"\n\n        grouped_view = data.groupby(\"controlled_by_param\")\n        # Because of this `x.index.values` we cannot support `make_trainable()` on\n        # the module level for synapse parameters (but only for `SynapseView`).\n        inds_of_comps = list(\n            grouped_view.apply(lambda x: x.index.values, include_groups=False)\n        )\n        indices_per_param = jnp.stack(inds_of_comps)\n        # Sorted inds are only used to infer the correct starting values.\n        param_vals = jnp.asarray(\n            [data.loc[inds, key].to_numpy() for inds in inds_of_comps]\n        )\n\n        # Set the value which the trainable parameter should take.\n        num_created_parameters = len(indices_per_param)\n        if init_val is not None:\n            if isinstance(init_val, float):\n                new_params = jnp.asarray([init_val] * num_created_parameters)\n            elif isinstance(init_val, list):\n                assert (\n                    len(init_val) == num_created_parameters\n                ), f\"len(init_val)={len(init_val)}, but trying to create {num_created_parameters} parameters.\"\n                new_params = jnp.asarray(init_val)\n            else:\n                raise ValueError(\n                    f\"init_val must a float, list, or None, but it is a {type(init_val).__name__}.\"\n                )\n        else:\n            new_params = jnp.mean(param_vals, axis=1)\n        self.base.trainable_params.append({key: new_params})\n        self.base.indices_set_by_trainables.append(indices_per_param)\n        self.base.num_trainable_params += num_created_parameters\n        if verbose:\n            print(\n                f\"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}\"\n            )\n\n    def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):\n        \"\"\"Write the trainables into `.nodes` and `.edges`.\n\n        This allows to, e.g., visualize trained networks with `.vis()`.\n\n        Args:\n            trainable_params: The trainable parameters returned by `get_parameters()`.\n        \"\"\"\n        # We do not support views. Why? `jaxedges` does not have any NaN\n        # elements, whereas edges does. Because of this, we already need special\n        # treatment to make this function work, and it would be an even bigger hassle\n        # if we wanted to support this.\n        assert self.__class__.__name__ in [\n            \"Compartment\",\n            \"Branch\",\n            \"Cell\",\n            \"Network\",\n        ], \"Only supports modules.\"\n\n        # We could also implement this without casting the module to jax.\n        # However, I think it allows us to reuse as much code as possible and it avoids\n        # any kind of issues with indexing or parameter sharing (as this is fully\n        # taken care of by `get_all_parameters()`).\n        self.base.to_jax()\n        pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables)\n        all_params = self.base.get_all_parameters(pstate, voltage_solver=\"jaxley.stone\")\n\n        # The value for `delta_t` does not matter here because it is only used to\n        # compute the initial current. However, the initial current cannot be made\n        # trainable and so its value never gets used below.\n        all_states = self.base.get_all_states(pstate, all_params, delta_t=0.025)\n\n        # Loop only over the keys in `pstate` to avoid unnecessary computation.\n        for parameter in pstate:\n            key = parameter[\"key\"]\n            if key in self.base.nodes.columns:\n                vals_to_set = all_params if key in all_params.keys() else all_states\n                self.base.nodes[key] = vals_to_set[key]\n\n        # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n        # we allow parameter sharing.\n        edges = self.base.edges.to_dict(orient=\"list\")\n        for i, synapse in enumerate(self.base.synapses):\n            condition = np.asarray(edges[\"type_ind\"]) == i\n            for key in list(synapse.synapse_params.keys()):\n                self.base.edges.loc[condition, key] = all_params[key]\n            for key in list(synapse.synapse_states.keys()):\n                self.base.edges.loc[condition, key] = all_states[key]\n\n    def distance(self, endpoint: \"View\") -> float:\n        \"\"\"Return the direct distance between two compartments.\n        This does not compute the pathwise distance (which is currently not\n        implemented).\n        Args:\n            endpoint: The compartment to which to compute the distance to.\n        \"\"\"\n        assert len(self.xyzr) == 1 and len(endpoint.xyzr) == 1\n        start_xyz = np.mean(self.xyzr[0][:, :3], axis=0)\n        end_xyz = np.mean(endpoint.xyzr[0][:, :3], axis=0)\n        return np.sqrt(np.sum((start_xyz - end_xyz) ** 2))\n\n    def delete_trainables(self):\n        \"\"\"Removes all trainable parameters from the module.\"\"\"\n\n        if isinstance(self, View):\n            trainables_and_inds = self._filter_trainables(is_viewed=False)\n            self.base.indices_set_by_trainables = trainables_and_inds[0]\n            self.base.trainable_params = trainables_and_inds[1]\n            self.base.num_trainable_params -= self.num_trainable_params\n        else:\n            self.base.indices_set_by_trainables = []\n            self.base.trainable_params = []\n            self.base.num_trainable_params = 0\n        self._update_view()\n\n    def add_to_group(self, group_name: str):\n        \"\"\"Add a view of the module to a group.\n\n        Groups can then be indexed. For example:\n\n        .. code-block:: python\n\n            net.cell(0).add_to_group(\"excitatory\")\n            net.excitatory.set(\"radius\", 0.1)\n\n        Args:\n            group_name: The name of the group.\n        \"\"\"\n        if group_name not in self.base.groups:\n            self.base.groups[group_name] = self._nodes_in_view\n        else:\n            self.base.groups[group_name] = np.unique(\n                np.concatenate([self.base.groups[group_name], self._nodes_in_view])\n            )\n\n    def _get_state_names(self) -> Tuple[List, List]:\n        \"\"\"Collect all recordable / clampable states in the membrane and synapses.\n\n        Returns states seperated by comps and edges.\"\"\"\n        channel_states = [name for c in self.channels for name in c.channel_states]\n        synapse_states = [\n            name for s in self.synapses if s is not None for name in s.synapse_states\n        ]\n        membrane_states = [\"v\", \"i\"] + self.membrane_current_names\n        return (\n            channel_states + membrane_states,\n            synapse_states + self.synapse_current_names,\n        )\n\n    def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:\n        \"\"\"Get all trainable parameters.\n\n        The returned parameters should be passed to `jx.integrate(..., params=params).\n\n        Returns:\n            A list of all trainable parameters in the form of\n                [{\"gNa\": jnp.array([0.1, 0.2, 0.3])}, ...].\n        \"\"\"\n        return self.trainable_params\n\n    @only_allow_module\n    def get_all_parameters(\n        self, pstate: List[Dict], voltage_solver: str\n    ) -> Dict[str, jnp.ndarray]:\n        # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n        \"\"\"Return all parameters (and coupling conductances) needed to simulate.\n\n        Runs `_compute_axial_conductances()` and return every parameter that is needed\n        to solve the ODE. This includes conductances, radiuses, lengths,\n        axial_resistivities, but also coupling conductances.\n\n        This is done by first obtaining the current value of every parameter (not only\n        the trainable ones) and then replacing the trainable ones with the value\n        in `trainable_params()`. This function is run within `jx.integrate()`.\n\n        pstate can be obtained by calling `params_to_pstate()`.\n\n        .. code-block:: python\n\n            params = module.get_parameters() # i.e. [0, 1, 2]\n            pstate = params_to_pstate(params, module.indices_set_by_trainables)\n            module.to_jax() # needed for call to module.jaxnodes\n\n        Args:\n            pstate: The state of the trainable parameters. pstate takes the form\n                [{\n                    \"key\": \"gNa\", \"indices\": jnp.array([0, 1, 2]),\n                    \"val\": jnp.array([0.1, 0.2, 0.3])\n                }, ...].\n            voltage_solver: The voltage solver that is used. Since `jax.sparse` and\n                `jaxley.xyz` require different formats of the axial conductances, this\n                function will default to different building methods.\n\n        Returns:\n            A dictionary of all module parameters.\n        \"\"\"\n        params = {}\n        for key in [\"radius\", \"length\", \"axial_resistivity\", \"capacitance\"]:\n            params[key] = self.base.jaxnodes[key]\n\n        for channel in self.base.channels:\n            for channel_params in channel.channel_params:\n                params[channel_params] = self.base.jaxnodes[channel_params]\n\n        for synapse_params in self.base.synapse_param_names:\n            params[synapse_params] = self.base.jaxedges[synapse_params]\n\n        # Override with those parameters set by `.make_trainable()`.\n        for parameter in pstate:\n            key = parameter[\"key\"]\n            inds = parameter[\"indices\"]\n            set_param = parameter[\"val\"]\n\n            # This is needed since SynapseViews worked differently before.\n            # This mimics the old behaviour and tranformes the new indices\n            # to the old indices.\n            # TODO FROM #447: Longterm this should be gotten rid of.\n            # Instead edges should work similar to nodes (would also allow for\n            # param sharing).\n            synapse_inds = self.base.edges.groupby(\"type\").rank()[\"global_edge_index\"]\n            synapse_inds = (synapse_inds.astype(int) - 1).to_numpy()\n            if key in self.base.synapse_param_names:\n                inds = synapse_inds[inds]\n\n            if key in params:  # Only parameters, not initial states.\n                # `inds` is of shape `(num_params, num_comps_per_param)`.\n                # `set_param` is of shape `(num_params,)`\n                # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n                # `.set()` to work. This is done with `[:, None]`.\n                params[key] = params[key].at[inds].set(set_param[:, None])\n\n        # Compute conductance params and add them to the params dictionary.\n        params[\"axial_conductances\"] = self.base._compute_axial_conductances(\n            params=params\n        )\n        return params\n\n    @only_allow_module\n    def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]:\n        # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n        \"\"\"Return states as they are set in the `.nodes` and `.edges` tables.\"\"\"\n        self.base.to_jax()  # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.\n        states = {\"v\": self.base.jaxnodes[\"v\"]}\n        # Join node and edge states into a single state dictionary.\n        for channel in self.base.channels:\n            for channel_states in channel.channel_states:\n                states[channel_states] = self.base.jaxnodes[channel_states]\n        for synapse_states in self.base.synapse_state_names:\n            states[synapse_states] = self.base.jaxedges[synapse_states]\n        return states\n\n    @only_allow_module\n    def get_all_states(\n        self, pstate: List[Dict], all_params, delta_t: float\n    ) -> Dict[str, jnp.ndarray]:\n        # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n        \"\"\"Get the full initial state of the module from jaxnodes and trainables.\n\n        Args:\n            pstate: The state of the trainable parameters.\n            all_params: All parameters of the module.\n            delta_t: The time step.\n\n        Returns:\n            A dictionary of all states of the module.\n        \"\"\"\n        states = self.base._get_states_from_nodes_and_edges()\n\n        # Override with the initial states set by `.make_trainable()`.\n        for parameter in pstate:\n            key = parameter[\"key\"]\n            inds = parameter[\"indices\"]\n            set_param = parameter[\"val\"]\n            if key in states:  # Only initial states, not parameters.\n                # `inds` is of shape `(num_params, num_comps_per_param)`.\n                # `set_param` is of shape `(num_params,)`\n                # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n                # `.set()` to work. This is done with `[:, None]`.\n                states[key] = states[key].at[inds].set(set_param[:, None])\n\n        # Add to the states the initial current through every channel.\n        states, _ = self.base._channel_currents(\n            states, delta_t, self.channels, self.nodes, all_params\n        )\n\n        # Add to the states the initial current through every synapse.\n        states, _ = self.base._synapse_currents(\n            states, self.synapses, all_params, delta_t, self.edges\n        )\n        return states\n\n    @property\n    def initialized(self) -> bool:\n        \"\"\"Whether the `Module` is ready to be solved or not.\"\"\"\n        return self.initialized_morph\n\n    def _initialize(self):\n        \"\"\"Initialize the module.\"\"\"\n        self._init_morph()\n        return self\n\n    @only_allow_module\n    def init_states(self, delta_t: float = 0.025):\n        # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n        \"\"\"Initialize all mechanisms in their steady state.\n\n        This considers the voltages and parameters of each compartment.\n\n        Args:\n            delta_t: Passed on to `channel.init_state()`.\n        \"\"\"\n        # Update states of the channels.\n        channel_nodes = self.base.nodes\n        states = self.base._get_states_from_nodes_and_edges()\n\n        # We do not use any `pstate` for initializing. In principle, we could change\n        # that by allowing an input `params` and `pstate` to this function.\n        # `voltage_solver` could also be `jax.sparse` here, because both of them\n        # build the channel parameters in the same way.\n        params = self.base.get_all_parameters([], voltage_solver=\"jaxley.thomas\")\n\n        for channel in self.base.channels:\n            name = channel._name\n            channel_indices = channel_nodes.loc[channel_nodes[name]][\n                \"global_comp_index\"\n            ].to_numpy()\n            voltages = channel_nodes.loc[channel_indices, \"v\"].to_numpy()\n\n            channel_param_names = list(channel.channel_params.keys())\n            channel_state_names = list(channel.channel_states.keys())\n            channel_states = query_channel_states_and_params(\n                states, channel_state_names, channel_indices\n            )\n            channel_params = query_channel_states_and_params(\n                params, channel_param_names, channel_indices\n            )\n\n            init_state = channel.init_state(\n                channel_states, voltages, channel_params, delta_t\n            )\n\n            # `init_state` might not return all channel states. Only the ones that are\n            # returned are updated here.\n            for key, val in init_state.items():\n                # Note that we are overriding `self.nodes` here, but `self.nodes` is\n                # not used above to actually compute the current states (so there are\n                # no issues with overriding states).\n                self.nodes.loc[channel_indices, key] = val\n\n    def _init_morph_for_debugging(self):\n        \"\"\"Instandiates row and column inds which can be used to solve the voltage eqs.\n\n        This is important only for expert users who try to modify the solver for the\n        voltage equations. By default, this function is never run.\n\n        This is useful for debugging the solver because one can use\n        `scipy.linalg.sparse.spsolve` after every step of the solve.\n\n        Here is the code snippet that can be used for debugging then (to be inserted in\n        `solver_voltage`):\n        ```python\n        from scipy.sparse import csc_matrix\n        from scipy.sparse.linalg import spsolve\n        from jaxley.utils.debug_solver import build_voltage_matrix_elements\n\n        elements, solve, num_entries, start_ind_for_branchpoints = (\n            build_voltage_matrix_elements(\n                uppers,\n                lowers,\n                diags,\n                solves,\n                branchpoint_conds_children[debug_states[\"child_inds\"]],\n                branchpoint_conds_parents[debug_states[\"par_inds\"]],\n                branchpoint_weights_children[debug_states[\"child_inds\"]],\n                branchpoint_weights_parents[debug_states[\"par_inds\"]],\n                branchpoint_diags,\n                branchpoint_solves,\n                debug_states[\"ncomp\"],\n                nbranches,\n            )\n        )\n        sparse_matrix = csc_matrix(\n            (elements, (debug_states[\"row_inds\"], debug_states[\"col_inds\"])),\n            shape=(num_entries, num_entries),\n        )\n        solution = spsolve(sparse_matrix, solve)\n        solution = solution[:start_ind_for_branchpoints]  # Delete branchpoint voltages.\n        solves = jnp.reshape(solution, (debug_states[\"ncomp\"], nbranches))\n        return solves\n        ```\n        \"\"\"\n        # For scipy and jax.scipy.\n        row_and_col_inds = compute_morphology_indices(\n            len(self.base._par_inds),\n            self.base._child_belongs_to_branchpoint,\n            self.base._par_inds,\n            self.base._child_inds,\n            self.base.ncomp,\n            self.base.total_nbranches,\n        )\n\n        num_elements = len(row_and_col_inds[\"row_inds\"])\n        data_inds, indices, indptr = convert_to_csc(\n            num_elements=num_elements,\n            row_ind=row_and_col_inds[\"row_inds\"],\n            col_ind=row_and_col_inds[\"col_inds\"],\n        )\n        self.base.debug_states[\"row_inds\"] = row_and_col_inds[\"row_inds\"]\n        self.base.debug_states[\"col_inds\"] = row_and_col_inds[\"col_inds\"]\n        self.base.debug_states[\"data_inds\"] = data_inds\n        self.base.debug_states[\"indices\"] = indices\n        self.base.debug_states[\"indptr\"] = indptr\n\n        self.base.debug_states[\"ncomp\"] = self.base.ncomp\n        self.base.debug_states[\"child_inds\"] = self.base._child_inds\n        self.base.debug_states[\"par_inds\"] = self.base._par_inds\n\n    def record(self, state: str = \"v\", verbose=True):\n        comp_states, edge_states = self._get_state_names()\n        if state not in comp_states + edge_states:\n            raise KeyError(f\"{state} is not a recognized state in this module.\")\n        in_view = self._nodes_in_view if state in comp_states else self._edges_in_view\n\n        new_recs = pd.DataFrame(in_view, columns=[\"rec_index\"])\n        new_recs[\"state\"] = state\n        self.base.recordings = pd.concat([self.base.recordings, new_recs])\n        has_duplicates = self.base.recordings.duplicated()\n        self.base.recordings = self.base.recordings.loc[~has_duplicates]\n        if verbose:\n            print(\n                f\"Added {len(in_view)-sum(has_duplicates)} recordings. See `.recordings` for details.\"\n            )\n\n    def _update_view(self):\n        \"\"\"Update the attrs of the view after changes in the base module.\"\"\"\n        if isinstance(self, View):\n            scope = self._scope\n            current_view = self._current_view\n            # copy dict of new View. For some reason doing self = View(self)\n            # did not work.\n            self.__dict__ = View(\n                self.base, self._nodes_in_view, self._edges_in_view\n            ).__dict__\n\n            # retain the scope and current_view of the previous view\n            self._scope = scope\n            self._current_view = current_view\n\n    def delete_recordings(self):\n        \"\"\"Removes all recordings from the module.\"\"\"\n        if isinstance(self, View):\n            base_recs = self.base.recordings\n            self.base.recordings = base_recs[\n                ~base_recs.isin(self.recordings).all(axis=1)\n            ]\n            self._update_view()\n        else:\n            self.base.recordings = pd.DataFrame().from_dict({})\n\n    def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True):\n        \"\"\"Insert a stimulus into the compartment.\n\n        current must be a 1d array or have batch dimension of size `(num_compartments, )`\n        or `(1, )`. If 1d, the same stimulus is added to all compartments.\n\n        This function cannot be run during `jax.jit` and `jax.grad`. Because of this,\n        it should only be used for static stimuli (i.e., stimuli that do not depend\n        on the data and that should not be learned). For stimuli that depend on data\n        (or that should be learned), please use `data_stimulate()`.\n\n        Args:\n            current: Current in `nA`.\n        \"\"\"\n        self._external_input(\"i\", current, verbose=verbose)\n\n    def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True):\n        \"\"\"Clamp a state to a given value across specified compartments.\n\n        Args:\n            state_name: The name of the state to clamp.\n            state_array (jnp.nd: Array of values to clamp the state to.\n            verbose : If True, prints details about the clamping.\n\n        This function sets external states for the compartments.\n        \"\"\"\n        self._external_input(state_name, state_array, verbose=verbose)\n\n    def _external_input(\n        self,\n        key: str,\n        values: Optional[jnp.ndarray],\n        verbose: bool = True,\n    ):\n        comp_states, edge_states = self._get_state_names()\n        if key not in comp_states + edge_states:\n            raise KeyError(f\"{key} is not a recognized state in this module.\")\n        values = values if values.ndim == 2 else jnp.expand_dims(values, axis=0)\n        batch_size = values.shape[0]\n        num_inserted = (\n            len(self._nodes_in_view) if key in comp_states else len(self._edges_in_view)\n        )\n        is_multiple = num_inserted == batch_size\n        values = values if is_multiple else jnp.repeat(values, num_inserted, axis=0)\n        assert batch_size in [\n            1,\n            num_inserted,\n        ], \"Number of comps and stimuli do not match.\"\n\n        if key in self.base.externals.keys():\n            self.base.externals[key] = jnp.concatenate(\n                [self.base.externals[key], values]\n            )\n            self.base.external_inds[key] = jnp.concatenate(\n                [self.base.external_inds[key], self._nodes_in_view]\n            )\n        else:\n            if key in comp_states:\n                self.base.externals[key] = values\n                self.base.external_inds[key] = self._nodes_in_view\n            else:\n                self.base.externals[key] = values\n                self.base.external_inds[key] = self._edges_in_view\n        if verbose:\n            print(\n                f\"Added {num_inserted} external_states. See `.externals` for details.\"\n            )\n\n    def data_stimulate(\n        self,\n        current: jnp.ndarray,\n        data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n        verbose: bool = False,\n    ) -> Tuple[jnp.ndarray, pd.DataFrame]:\n        \"\"\"Insert a stimulus into the module within jit (or grad).\n\n        Args:\n            current: Current in `nA`.\n            verbose: Whether or not to print the number of inserted stimuli. `False`\n                by default because this method is meant to be jitted.\n        \"\"\"\n        return self._data_external_input(\n            \"i\", current, data_stimuli, self.nodes, verbose=verbose\n        )\n\n    def data_clamp(\n        self,\n        state_name: str,\n        state_array: jnp.ndarray,\n        data_clamps: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n        verbose: bool = False,\n    ):\n        \"\"\"Insert a clamp into the module within jit (or grad).\n\n        Args:\n            state_name: Name of the state variable to set.\n            state_array: Time series of the state variable in the default Jaxley unit.\n                State array should be of shape (num_clamps, simulation_time) or\n                (simulation_time, ) for a single clamp.\n            verbose: Whether or not to print the number of inserted clamps. `False`\n                by default because this method is meant to be jitted.\n        \"\"\"\n        comp_states, edge_states = self._get_state_names()\n        if state_name not in comp_states + edge_states:\n            raise KeyError(f\"{state_name} is not a recognized state in this module.\")\n        data = self.nodes if state_name in comp_states else self.edges\n        return self._data_external_input(\n            state_name, state_array, data_clamps, data, verbose=verbose\n        )\n\n    def _data_external_input(\n        self,\n        state_name: str,\n        state_array: jnp.ndarray,\n        data_external_input: Optional[Tuple[jnp.ndarray, pd.DataFrame]],\n        view: pd.DataFrame,\n        verbose: bool = False,\n    ):\n        comp_states, edge_states = self._get_state_names()\n        state_array = (\n            state_array\n            if state_array.ndim == 2\n            else jnp.expand_dims(state_array, axis=0)\n        )\n        batch_size = state_array.shape[0]\n        num_inserted = (\n            len(self._nodes_in_view)\n            if state_name in comp_states\n            else len(self._edges_in_view)\n        )\n        is_multiple = num_inserted == batch_size\n        state_array = (\n            state_array\n            if is_multiple\n            else jnp.repeat(state_array, num_inserted, axis=0)\n        )\n        assert batch_size in [\n            1,\n            num_inserted,\n        ], \"Number of comps and clamps do not match.\"\n\n        if data_external_input is not None:\n            external_input = data_external_input[1]\n            external_input = jnp.concatenate([external_input, state_array])\n            inds = data_external_input[2]\n        else:\n            external_input = state_array\n            inds = pd.DataFrame().from_dict({})\n\n        inds = pd.concat([inds, view])\n\n        if verbose:\n            if state_name == \"i\":\n                print(f\"Added {len(view)} stimuli.\")\n            else:\n                print(f\"Added {len(view)} clamps.\")\n\n        return (state_name, external_input, inds)\n\n    def delete_stimuli(self):\n        \"\"\"Removes all stimuli from the module.\"\"\"\n        self.delete_clamps(\"i\")\n\n    def delete_clamps(self, state_name: Optional[str] = None):\n        \"\"\"Removes all clamps of the given state from the module.\"\"\"\n        all_externals = list(self.externals.keys())\n        if \"i\" in all_externals:\n            all_externals.remove(\"i\")\n        state_names = all_externals if state_name is None else [state_name]\n        for state_name in state_names:\n            if state_name in self.externals:\n                keep_inds = ~np.isin(\n                    self.base.external_inds[state_name], self._nodes_in_view\n                )\n                base_exts = self.base.externals\n                base_exts_inds = self.base.external_inds\n                if np.all(~keep_inds):\n                    base_exts.pop(state_name, None)\n                    base_exts_inds.pop(state_name, None)\n                else:\n                    base_exts[state_name] = base_exts[state_name][keep_inds]\n                    base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds]\n                self._update_view()\n            else:\n                pass  # does not have to be deleted if not in externals\n\n    def insert(self, channel: Channel):\n        \"\"\"Insert a channel into the module.\n\n        Args:\n            channel: The channel to insert.\"\"\"\n        name = channel._name\n\n        # Channel does not yet exist in the `jx.Module` at all.\n        if name not in [c._name for c in self.base.channels]:\n            self.base.channels.append(channel)\n            self.base.nodes[name] = (\n                False  # Previous columns do not have the new channel.\n            )\n\n        if channel.current_name not in self.base.membrane_current_names:\n            self.base.membrane_current_names.append(channel.current_name)\n\n        # Add a binary column that indicates if a channel is present.\n        self.base.nodes.loc[self._nodes_in_view, name] = True\n\n        # Loop over all new parameters, e.g. gNa, eNa.\n        for key in channel.channel_params:\n            self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_params[key]\n\n        # Loop over all new parameters, e.g. gNa, eNa.\n        for key in channel.channel_states:\n            self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key]\n\n    def delete_channel(self, channel: Channel):\n        \"\"\"Remove a channel from the module.\n\n        Args:\n            channel: The channel to remove.\"\"\"\n        name = channel._name\n        channel_names = [c._name for c in self.channels]\n        all_channel_names = [c._name for c in self.base.channels]\n        if name in channel_names:\n            channel_cols = list(channel.channel_params.keys())\n            channel_cols += list(channel.channel_states.keys())\n            self.base.nodes.loc[self._nodes_in_view, channel_cols] = float(\"nan\")\n            self.base.nodes.loc[self._nodes_in_view, name] = False\n\n            # only delete cols if no other comps in the module have the same channel\n            if np.all(~self.base.nodes[name]):\n                self.base.channels.pop(all_channel_names.index(name))\n                self.base.membrane_current_names.remove(channel.current_name)\n                self.base.nodes.drop(columns=channel_cols + [name], inplace=True)\n        else:\n            raise ValueError(f\"Channel {name} not found in the module.\")\n\n    @only_allow_module\n    def step(\n        self,\n        u: Dict[str, jnp.ndarray],\n        delta_t: float,\n        external_inds: Dict[str, jnp.ndarray],\n        externals: Dict[str, jnp.ndarray],\n        params: Dict[str, jnp.ndarray],\n        solver: str = \"bwd_euler\",\n        voltage_solver: str = \"jaxley.stone\",\n    ) -> Dict[str, jnp.ndarray]:\n        \"\"\"One step of solving the Ordinary Differential Equation.\n\n        This function is called inside of `integrate` and increments the state of the\n        module by one time step. Calls `_step_channels` and `_step_synapse` to update\n        the states of the channels and synapses using fwd_euler.\n\n        Args:\n            u: The state of the module. voltages = u[\"v\"]\n            delta_t: The time step.\n            external_inds: The indices of the external inputs.\n            externals: The external inputs.\n            params: The parameters of the module.\n            solver: The solver to use for the voltages. Either of [\"bwd_euler\",\n                \"fwd_euler\", \"crank_nicolson\"].\n            voltage_solver: The tridiagonal solver used to diagonalize the\n                coefficient matrix of the ODE system. Either of [\"jaxley.thomas\",\n                \"jaxley.stone\"].\n\n        Returns:\n            The updated state of the module.\n        \"\"\"\n\n        # Extract the voltages\n        voltages = u[\"v\"]\n\n        # Extract the external inputs\n        if \"i\" in externals.keys():\n            i_current = externals[\"i\"]\n            i_inds = external_inds[\"i\"]\n            i_ext = self._get_external_input(\n                voltages, i_inds, i_current, params[\"radius\"], params[\"length\"]\n            )\n        else:\n            i_ext = 0.0\n\n        # Step of the channels.\n        u, (v_terms, const_terms) = self._step_channels(\n            u, delta_t, self.channels, self.nodes, params\n        )\n\n        # Step of the synapse.\n        u, (syn_v_terms, syn_const_terms) = self._step_synapse(\n            u,\n            self.synapses,\n            params,\n            delta_t,\n            self.edges,\n        )\n\n        # Clamp for channels and synapses.\n        for key in externals.keys():\n            if key not in [\"i\", \"v\"]:\n                u[key] = u[key].at[external_inds[key]].set(externals[key])\n\n        # Voltage steps.\n        cm = params[\"capacitance\"]  # Abbreviation.\n\n        # Arguments used by all solvers.\n        solver_kwargs = {\n            \"voltages\": voltages,\n            \"voltage_terms\": (v_terms + syn_v_terms) / cm,\n            \"constant_terms\": (const_terms + i_ext + syn_const_terms) / cm,\n            \"axial_conductances\": params[\"axial_conductances\"],\n            \"internal_node_inds\": self._internal_node_inds,\n        }\n\n        # Add solver specific arguments.\n        if voltage_solver == \"jax.sparse\":\n            solver_kwargs.update(\n                {\n                    \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n                    \"data_inds\": self._data_inds,\n                    \"indices\": self._indices_jax_spsolve,\n                    \"indptr\": self._indptr_jax_spsolve,\n                    \"n_nodes\": self._n_nodes,\n                }\n            )\n            # Only for `bwd_euler` and `cranck-nicolson`.\n            step_voltage_implicit = step_voltage_implicit_with_jax_spsolve\n        else:\n            # Our custom sparse solver requires a different format of all conductance\n            # values to perform triangulation and backsubstution optimally.\n            #\n            # Currently, the forward Euler solver also uses this format. However,\n            # this is only for historical reasons and we are planning to change this in\n            # the future.\n            solver_kwargs.update(\n                {\n                    \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n                    \"sources\": np.asarray(self._comp_edges[\"source\"].to_list()),\n                    \"types\": np.asarray(self._comp_edges[\"type\"].to_list()),\n                    \"ncomp_per_branch\": self.ncomp_per_branch,\n                    \"par_inds\": self._par_inds,\n                    \"child_inds\": self._child_inds,\n                    \"nbranches\": self.total_nbranches,\n                    \"solver\": voltage_solver,\n                    \"idx\": self._solve_indexer,\n                    \"debug_states\": self.debug_states,\n                }\n            )\n            # Only for `bwd_euler` and `cranck-nicolson`.\n            step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve\n\n        if solver == \"bwd_euler\":\n            u[\"v\"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t)\n        elif solver == \"crank_nicolson\":\n            # Crank-Nicolson advances by half a step of backward and half a step of\n            # forward Euler.\n            half_step_delta_t = delta_t / 2\n            half_step_voltages = step_voltage_implicit(\n                **solver_kwargs, delta_t=half_step_delta_t\n            )\n            # The forward Euler step in Crank-Nicolson can be performed easily as\n            # `V_{n+1} = 2 * V_{n+1/2} - V_n`. See also NEURON book Chapter 4.\n            u[\"v\"] = 2 * half_step_voltages - voltages\n        elif solver == \"fwd_euler\":\n            u[\"v\"] = step_voltage_explicit(**solver_kwargs, delta_t=delta_t)\n        else:\n            raise ValueError(\n                f\"You specified `solver={solver}`. The only allowed solvers are \"\n                \"['bwd_euler', 'fwd_euler', 'crank_nicolson'].\"\n            )\n\n        # Clamp for voltages.\n        if \"v\" in externals.keys():\n            u[\"v\"] = u[\"v\"].at[external_inds[\"v\"]].set(externals[\"v\"])\n\n        return u\n\n    def _step_channels(\n        self,\n        states: Dict[str, jnp.ndarray],\n        delta_t: float,\n        channels: List[Channel],\n        channel_nodes: pd.DataFrame,\n        params: Dict[str, jnp.ndarray],\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"One step of integration of the channels and of computing their current.\"\"\"\n        states = self._step_channels_state(\n            states, delta_t, channels, channel_nodes, params\n        )\n        states, current_terms = self._channel_currents(\n            states, delta_t, channels, channel_nodes, params\n        )\n        return states, current_terms\n\n    def _step_channels_state(\n        self,\n        states,\n        delta_t,\n        channels: List[Channel],\n        channel_nodes: pd.DataFrame,\n        params: Dict[str, jnp.ndarray],\n    ) -> Dict[str, jnp.ndarray]:\n        \"\"\"One integration step of the channels.\"\"\"\n        voltages = states[\"v\"]\n\n        # Update states of the channels.\n        indices = channel_nodes[\"global_comp_index\"].to_numpy()\n        for channel in channels:\n            channel_param_names = list(channel.channel_params)\n            channel_param_names += [\n                \"radius\",\n                \"length\",\n                \"axial_resistivity\",\n                \"capacitance\",\n            ]\n            channel_state_names = list(channel.channel_states)\n            channel_state_names += self.membrane_current_names\n            channel_indices = indices[channel_nodes[channel._name].astype(bool)]\n\n            channel_params = query_channel_states_and_params(\n                params, channel_param_names, channel_indices\n            )\n            channel_states = query_channel_states_and_params(\n                states, channel_state_names, channel_indices\n            )\n\n            states_updated = channel.update_states(\n                channel_states, delta_t, voltages[channel_indices], channel_params\n            )\n            # Rebuild state. This has to be done within the loop over channels to allow\n            # multiple channels which modify the same state.\n            for key, val in states_updated.items():\n                states[key] = states[key].at[channel_indices].set(val)\n\n        return states\n\n    def _channel_currents(\n        self,\n        states: Dict[str, jnp.ndarray],\n        delta_t: float,\n        channels: List[Channel],\n        channel_nodes: pd.DataFrame,\n        params: Dict[str, jnp.ndarray],\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"Return the current through each channel.\n\n        This is also updates `state` because the `state` also contains the current.\n        \"\"\"\n        voltages = states[\"v\"]\n\n        # Compute current through channels.\n        voltage_terms = jnp.zeros_like(voltages)\n        constant_terms = jnp.zeros_like(voltages)\n        # Run with two different voltages that are `diff` apart to infer the slope and\n        # offset.\n        diff = 1e-3\n\n        current_states = {}\n        for name in self.membrane_current_names:\n            current_states[name] = jnp.zeros_like(voltages)\n\n        for channel in channels:\n            name = channel._name\n            channel_param_names = list(channel.channel_params.keys())\n            channel_state_names = list(channel.channel_states.keys())\n            indices = channel_nodes.loc[channel_nodes[name]][\n                \"global_comp_index\"\n            ].to_numpy()\n\n            channel_params = {}\n            for p in channel_param_names:\n                channel_params[p] = params[p][indices]\n            channel_params[\"radius\"] = params[\"radius\"][indices]\n            channel_params[\"length\"] = params[\"length\"][indices]\n            channel_params[\"axial_resistivity\"] = params[\"axial_resistivity\"][indices]\n\n            channel_states = {}\n            for s in channel_state_names:\n                channel_states[s] = states[s][indices]\n\n            v_and_perturbed = jnp.stack([voltages[indices], voltages[indices] + diff])\n            membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))(\n                channel_states, v_and_perturbed, channel_params\n            )\n            voltage_term = (membrane_currents[1] - membrane_currents[0]) / diff\n            constant_term = membrane_currents[0] - voltage_term * voltages[indices]\n\n            # * 1000 to convert from mA/cm^2 to uA/cm^2.\n            voltage_terms = voltage_terms.at[indices].add(voltage_term * 1000.0)\n            constant_terms = constant_terms.at[indices].add(-constant_term * 1000.0)\n\n            # Save the current (for the unperturbed voltage) as a state that will\n            # also be passed to the state update.\n            current_states[channel.current_name] = (\n                current_states[channel.current_name]\n                .at[indices]\n                .add(membrane_currents[0])\n            )\n\n        # Copy the currents into the `state` dictionary such that they can be\n        # recorded and used by `Channel.update_states()`.\n        for name in self.membrane_current_names:\n            states[name] = current_states[name]\n\n        return states, (voltage_terms, constant_terms)\n\n    def _step_synapse(\n        self,\n        u: Dict[str, jnp.ndarray],\n        syn_channels: List[Channel],\n        params: Dict[str, jnp.ndarray],\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"One step of integration of the channels.\n\n        `Network` overrides this method (because it actually has synapses), whereas\n        `Compartment`, `Branch`, and `Cell` do not override this.\n        \"\"\"\n        voltages = u[\"v\"]\n        return u, (jnp.zeros_like(voltages), jnp.zeros_like(voltages))\n\n    def _synapse_currents(\n        self, states, syn_channels, params, delta_t, edges: pd.DataFrame\n    ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n        return states, (None, None)\n\n    @staticmethod\n    def _get_external_input(\n        voltages: jnp.ndarray,\n        i_inds: jnp.ndarray,\n        i_stim: jnp.ndarray,\n        radius: float,\n        length_single_compartment: float,\n    ) -> jnp.ndarray:\n        \"\"\"\n        Return external input to each compartment in uA / cm^2.\n\n        Args:\n            voltages: mV.\n            i_stim: nA.\n            radius: um.\n            length_single_compartment: um.\n        \"\"\"\n        zero_vec = jnp.zeros_like(voltages)\n        current = convert_point_process_to_distributed(\n            i_stim, radius[i_inds], length_single_compartment[i_inds]\n        )\n\n        dnums = ScatterDimensionNumbers(\n            update_window_dims=(),\n            inserted_window_dims=(0,),\n            scatter_dims_to_operand_dims=(0,),\n        )\n        stim_at_timestep = scatter_add(zero_vec, i_inds[:, None], current, dnums)\n        return stim_at_timestep\n\n    def vis(\n        self,\n        ax: Optional[Axes] = None,\n        color: str = \"k\",\n        dims: Tuple[int] = (0, 1),\n        type: str = \"line\",\n        **kwargs,\n    ) -> Axes:\n        \"\"\"Visualize the module.\n\n        Modules can be visualized on one of the cardinal planes (xy, xz, yz) or\n        even in 3D.\n\n        Several options are available:\n        - `line`: All points from the traced morphology (`xyzr`), are connected\n        with a line plot.\n        - `scatter`: All traced points, are plotted as scatter points.\n        - `comp`: Plots the compartmentalized morphology, including radius\n        and shape. (shows the true compartment lengths per default, but this can\n        be changed via the `kwargs`, for details see\n        `jaxley.utils.plot_utils.plot_comps`).\n        - `morph`: Reconstructs the 3D shape of the traced morphology. For details see\n        `jaxley.utils.plot_utils.plot_morph`. Warning: For 3D plots and morphologies\n        with many traced points this can be very slow.\n\n        Args:\n            ax: An axis into which to plot.\n            color: The color for all branches.\n            dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n                two of them.\n            type: The type of plot. One of [\"line\", \"scatter\", \"comp\", \"morph\"].\n            kwargs: Keyword arguments passed to the plotting function.\n        \"\"\"\n        res = 100 if \"resolution\" not in kwargs else kwargs.pop(\"resolution\")\n        if \"comp\" in type.lower():\n            return plot_comps(\n                self, dims=dims, ax=ax, color=color, resolution=res, **kwargs\n            )\n        if \"morph\" in type.lower():\n            return plot_morph(\n                self, dims=dims, ax=ax, color=color, resolution=res, **kwargs\n            )\n\n        assert not np.any(\n            [np.isnan(xyzr[:, dims]).all() for xyzr in self.xyzr]\n        ), \"No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`.\"\n\n        ax = plot_graph(\n            self.xyzr,\n            dims=dims,\n            color=color,\n            ax=ax,\n            type=type,\n            **kwargs,\n        )\n\n        return ax\n\n    def compute_xyz(self):\n        \"\"\"Return xyz coordinates of every branch, based on the branch length.\n\n        This function should not be called if the morphology was read from an `.swc`\n        file. However, for morphologies that were constructed from scratch, this\n        function **must** be called before `.vis()`. The computed `xyz` coordinates\n        are only used for plotting.\n        \"\"\"\n        max_y_multiplier = 5.0\n        min_y_multiplier = 0.5\n\n        parents = self.comb_parents\n        num_children = _compute_num_children(parents)\n        index_of_child = _compute_index_of_child(parents)\n        levels = compute_levels(parents)\n\n        # Extract branch.\n        inds_branch = self.nodes.groupby(\"global_branch_index\")[\n            \"global_comp_index\"\n        ].apply(list)\n        branch_lens = [np.sum(self.nodes[\"length\"][np.asarray(i)]) for i in inds_branch]\n        endpoints = []\n\n        # Different levels will get a different \"angle\" at which the children emerge from\n        # the parents. This angle is defined by the `y_offset_multiplier`. This value\n        # defines the range between y-location of the first and of the last child of a\n        # parent.\n        y_offset_multiplier = np.linspace(\n            max_y_multiplier, min_y_multiplier, np.max(levels) + 1\n        )\n\n        for b in range(self.total_nbranches):\n            # For networks with mixed SWC and from-scatch neurons, only update those\n            # branches that do not have coordingates yet.\n            if np.any(np.isnan(self.xyzr[b])):\n                if parents[b] > -1:\n                    start_point = endpoints[parents[b]]\n                    num_children_of_parent = num_children[parents[b]]\n                    if num_children_of_parent > 1:\n                        y_offset = (\n                            ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5\n                        ) * y_offset_multiplier[levels[b]]\n                    else:\n                        y_offset = 0.0\n                else:\n                    start_point = [0, 0, 0]\n                    y_offset = 0.0\n\n                len_of_path = np.sqrt(y_offset**2 + 1.0)\n\n                end_point = [\n                    start_point[0] + branch_lens[b] / len_of_path * 1.0,\n                    start_point[1] + branch_lens[b] / len_of_path * y_offset,\n                    start_point[2],\n                ]\n                endpoints.append(end_point)\n\n                self.xyzr[b][:, :3] = np.asarray([start_point, end_point])\n            else:\n                # Dummy to keey the index `endpoints[parent[b]]` above working.\n                endpoints.append(np.zeros((2,)))\n\n    def move(\n        self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = False\n    ):\n        \"\"\"Move cells or networks by adding to their (x, y, z) coordinates.\n\n        This function is used only for visualization. It does not affect the simulation.\n\n        Args:\n            x: The amount to move in the x direction in um.\n            y: The amount to move in the y direction in um.\n            z: The amount to move in the z direction in um.\n            update_nodes: Whether `.nodes` should be updated or not. Setting this to\n                `False` largely speeds up moving, especially for big networks, but\n                `.nodes` or `.show` will not show the new xyz coordinates.\n        \"\"\"\n        for i in self._branches_in_view:\n            self.base.xyzr[i][:, :3] += np.array([x, y, z])\n        if update_nodes:\n            self.compute_compartment_centers()\n\n    def move_to(\n        self,\n        x: Union[float, np.ndarray] = 0.0,\n        y: Union[float, np.ndarray] = 0.0,\n        z: Union[float, np.ndarray] = 0.0,\n        update_nodes: bool = False,\n    ):\n        \"\"\"Move cells or networks to a location (x, y, z).\n\n        If x, y, and z are floats, then the first compartment of the first branch\n        of the first cell is moved to that float coordinate, and everything else is\n        shifted by the difference between that compartment's previous coordinate and\n        the new float location.\n\n        If x, y, and z are arrays, then they must each have a length equal to the number\n        of cells being moved. Then the first compartment of the first branch of each\n        cell is moved to the specified location.\n\n        Args:\n            update_nodes: Whether `.nodes` should be updated or not. Setting this to\n                `False` largely speeds up moving, especially for big networks, but\n                `.nodes` or `.show` will not show the new xyz coordinates.\n        \"\"\"\n        # Test if any coordinate values are NaN which would greatly affect moving\n        if np.any(np.concatenate(self.xyzr, axis=0)[:, :3] == np.nan):\n            raise ValueError(\n                \"NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values.\"\n            )\n\n        # can only iterate over cells for networks\n        # lambda makes sure that generator can be created multiple times\n        base_is_net = self.base._current_view == \"network\"\n        cells = lambda: (self.cells if base_is_net else [self])\n\n        root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in cells()])\n        root_xyz = root_xyz_cells[0] if isinstance(x, float) else root_xyz_cells\n        move_by = np.array([x, y, z]).T - root_xyz\n\n        if len(move_by.shape) == 1:\n            move_by = np.tile(move_by, (len(self._cells_in_view), 1))\n\n        for cell, offset in zip(cells(), move_by):\n            for idx in cell._branches_in_view:\n                self.base.xyzr[idx][:, :3] += offset\n        if update_nodes:\n            self.compute_compartment_centers()\n\n    def rotate(\n        self, degrees: float, rotation_axis: str = \"xy\", update_nodes: bool = False\n    ):\n        \"\"\"Rotate jaxley modules clockwise. Used only for visualization.\n\n        This function is used only for visualization. It does not affect the simulation.\n\n        Args:\n            degrees: How many degrees to rotate the module by.\n            rotation_axis: Either of {`xy` | `xz` | `yz`}.\n        \"\"\"\n        degrees = degrees / 180 * np.pi\n        if rotation_axis == \"xy\":\n            dims = [0, 1]\n        elif rotation_axis == \"xz\":\n            dims = [0, 2]\n        elif rotation_axis == \"yz\":\n            dims = [1, 2]\n        else:\n            raise ValueError\n\n        rotation_matrix = np.asarray(\n            [[np.cos(degrees), np.sin(degrees)], [-np.sin(degrees), np.cos(degrees)]]\n        )\n        for i in self._branches_in_view:\n            rot = np.dot(rotation_matrix, self.base.xyzr[i][:, dims].T).T\n            self.base.xyzr[i][:, dims] = rot\n        if update_nodes:\n            self.compute_compartment_centers()\n\n    def copy_node_property_to_edges(\n        self,\n        properties_to_import: Union[str, List[str]],\n        pre_or_post: Union[str, List[str]] = [\"pre\", \"post\"],\n    ) -> Module:\n        \"\"\"Copy a property that is in `node` over to `edges`.\n\n        By default, `.edges` does not contain the properties (radius, length, cm,\n        channel properties,...) of the pre- and post-synaptic compartments. This\n        method allows to copy a property of the pre- and/or post-synaptic compartment\n        to the edges. It is then accessible as `module.edges.pre_property_name` or\n        `module.edges.post_property_name`.\n\n        Note that, if you modify the node property _after_ having run\n        `copy_node_property_to_edges`, it will not automatically update the value in\n        `.edges`.\n\n        Note that, if this method is called on a View (e.g.\n        `net.cell(0).copy_node_property_to_edges`), then it will return a View, but\n        it will _not_ modify the module itself.\n\n        Args:\n            properties_to_import: The name of the node properties that should be\n                imported. To list all available properties, look at\n                `module.nodes.columns`.\n            pre_or_post: Whether to import only the pre-synaptic property ('pre'), only\n                the post-synaptic property ('post'), or both (['pre', 'post']).\n\n        Returns:\n            A new module which has the property copied to the `nodes`.\n        \"\"\"\n        # If a string is passed, wrap it as a list.\n        if isinstance(pre_or_post, str):\n            pre_or_post = [pre_or_post]\n        if isinstance(properties_to_import, str):\n            properties_to_import = [properties_to_import]\n\n        for pre_or_post_val in pre_or_post:\n            assert pre_or_post_val in [\"pre\", \"post\"]\n            for property_to_import in properties_to_import:\n                # Delete the column if it already exists. Otherwise it would exist\n                # twice.\n                if f\"{pre_or_post_val}_{property_to_import}\" in self.edges.columns:\n                    self.edges.drop(\n                        columns=f\"{pre_or_post_val}_{property_to_import}\", inplace=True\n                    )\n\n                self.edges = self.edges.join(\n                    self.nodes[[property_to_import, \"global_comp_index\"]].set_index(\n                        \"global_comp_index\"\n                    ),\n                    on=f\"{pre_or_post_val}_global_comp_index\",\n                )\n                self.edges = self.edges.rename(\n                    columns={\n                        property_to_import: f\"{pre_or_post_val}_{property_to_import}\"\n                    }\n                )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.branches","title":"branches property","text":"

Iterate over all branches in the module.

Returns a generator that yields a View of each branch.

"},{"location":"reference/modules/#jaxley.modules.base.Module.cells","title":"cells property","text":"

Iterate over all cells in the module.

Returns a generator that yields a View of each cell.

"},{"location":"reference/modules/#jaxley.modules.base.Module.comps","title":"comps property","text":"

Iterate over all compartments in the module. Can be called on any module, i.e. net.comps, cell.comps or branch.comps. __iter__ does not allow for this.

Returns a generator that yields a View of each compartment.

"},{"location":"reference/modules/#jaxley.modules.base.Module.initialized","title":"initialized: bool property","text":"

Whether the Module is ready to be solved or not.

"},{"location":"reference/modules/#jaxley.modules.base.Module.shape","title":"shape: Tuple[int] property","text":"

Returns the number of submodules contained in a module.

.. code-block:: python

network.shape = (num_cells, num_branches, num_compartments)\ncell.shape = (num_branches, num_compartments)\nbranch.shape = (num_compartments,)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.view","title":"view property","text":"

Return view of the module.

"},{"location":"reference/modules/#jaxley.modules.base.Module.__getitem__","title":"__getitem__(index)","text":"

Lazy indexing of the module.

Source code in jaxley/modules/base.py
def __getitem__(self, index):\n    \"\"\"Lazy indexing of the module.\"\"\"\n    supported_parents = [\"network\", \"cell\", \"branch\"]  # cannot index into comp\n\n    not_group_view = self._current_view not in self.groups\n    assert (\n        self._current_view in supported_parents or not_group_view\n    ), \"Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof.\"\n    index = index if isinstance(index, tuple) else (index,)\n\n    child_views = self._childviews()\n    assert len(index) <= len(child_views), \"Too many indices.\"\n    view = self\n    for i, child in zip(index, child_views):\n        view = view._at_nodes(child, i)\n    return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.__iter__","title":"__iter__()","text":"

Iterate over parts of the module.

Internally calls cells, branches, comps at the appropriate level.

Example:

.. code-block:: python

for cell in network:\n    for branch in cell:\n        for comp in branch:\n            print(comp.nodes.shape)\n
Source code in jaxley/modules/base.py
def __iter__(self):\n    \"\"\"Iterate over parts of the module.\n\n    Internally calls `cells`, `branches`, `comps` at the appropriate level.\n\n    Example:\n\n    .. code-block:: python\n\n        for cell in network:\n            for branch in cell:\n                for comp in branch:\n                    print(comp.nodes.shape)\n    \"\"\"\n    next_level = self._childviews()[0]\n    yield from self._iter_submodules(next_level)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.add_to_group","title":"add_to_group(group_name)","text":"

Add a view of the module to a group.

Groups can then be indexed. For example:

.. code-block:: python

net.cell(0).add_to_group(\"excitatory\")\nnet.excitatory.set(\"radius\", 0.1)\n

Parameters:

Name Type Description Default group_name str

The name of the group.

required Source code in jaxley/modules/base.py
def add_to_group(self, group_name: str):\n    \"\"\"Add a view of the module to a group.\n\n    Groups can then be indexed. For example:\n\n    .. code-block:: python\n\n        net.cell(0).add_to_group(\"excitatory\")\n        net.excitatory.set(\"radius\", 0.1)\n\n    Args:\n        group_name: The name of the group.\n    \"\"\"\n    if group_name not in self.base.groups:\n        self.base.groups[group_name] = self._nodes_in_view\n    else:\n        self.base.groups[group_name] = np.unique(\n            np.concatenate([self.base.groups[group_name], self._nodes_in_view])\n        )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.branch","title":"branch(idx)","text":"

Return a View of the module at the selected branches(s).

Parameters:

Name Type Description Default idx Any

index of the branch to view.

required

Returns:

Type Description View

View of the module at the specified branch index.

Source code in jaxley/modules/base.py
def branch(self, idx: Any) -> View:\n    \"\"\"Return a View of the module at the selected branches(s).\n\n    Args:\n        idx: index of the branch to view.\n\n    Returns:\n        View of the module at the specified branch index.\"\"\"\n    return self._at_nodes(\"branch\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.cell","title":"cell(idx)","text":"

Return a View of the module at the selected cell(s).

Parameters:

Name Type Description Default idx Any

index of the cell to view.

required

Returns:

Type Description View

View of the module at the specified cell index.

Source code in jaxley/modules/base.py
def cell(self, idx: Any) -> View:\n    \"\"\"Return a View of the module at the selected cell(s).\n\n    Args:\n        idx: index of the cell to view.\n\n    Returns:\n        View of the module at the specified cell index.\"\"\"\n    return self._at_nodes(\"cell\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.clamp","title":"clamp(state_name, state_array, verbose=True)","text":"

Clamp a state to a given value across specified compartments.

Parameters:

Name Type Description Default state_name str

The name of the state to clamp.

required state_array nd

Array of values to clamp the state to.

required verbose

If True, prints details about the clamping.

True

This function sets external states for the compartments.

Source code in jaxley/modules/base.py
def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True):\n    \"\"\"Clamp a state to a given value across specified compartments.\n\n    Args:\n        state_name: The name of the state to clamp.\n        state_array (jnp.nd: Array of values to clamp the state to.\n        verbose : If True, prints details about the clamping.\n\n    This function sets external states for the compartments.\n    \"\"\"\n    self._external_input(state_name, state_array, verbose=verbose)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.comp","title":"comp(idx)","text":"

Return a View of the module at the selected compartments(s).

Parameters:

Name Type Description Default idx Any

index of the comp to view.

required

Returns:

Type Description View

View of the module at the specified compartment index.

Source code in jaxley/modules/base.py
def comp(self, idx: Any) -> View:\n    \"\"\"Return a View of the module at the selected compartments(s).\n\n    Args:\n        idx: index of the comp to view.\n\n    Returns:\n        View of the module at the specified compartment index.\"\"\"\n    return self._at_nodes(\"comp\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.compute_compartment_centers","title":"compute_compartment_centers()","text":"

Add compartment centers to nodes dataframe

Source code in jaxley/modules/base.py
def compute_compartment_centers(self):\n    \"\"\"Add compartment centers to nodes dataframe\"\"\"\n    centers = self._compute_coords_of_comp_centers()\n    self.base.nodes.loc[self._nodes_in_view, [\"x\", \"y\", \"z\"]] = centers\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.compute_xyz","title":"compute_xyz()","text":"

Return xyz coordinates of every branch, based on the branch length.

This function should not be called if the morphology was read from an .swc file. However, for morphologies that were constructed from scratch, this function must be called before .vis(). The computed xyz coordinates are only used for plotting.

Source code in jaxley/modules/base.py
def compute_xyz(self):\n    \"\"\"Return xyz coordinates of every branch, based on the branch length.\n\n    This function should not be called if the morphology was read from an `.swc`\n    file. However, for morphologies that were constructed from scratch, this\n    function **must** be called before `.vis()`. The computed `xyz` coordinates\n    are only used for plotting.\n    \"\"\"\n    max_y_multiplier = 5.0\n    min_y_multiplier = 0.5\n\n    parents = self.comb_parents\n    num_children = _compute_num_children(parents)\n    index_of_child = _compute_index_of_child(parents)\n    levels = compute_levels(parents)\n\n    # Extract branch.\n    inds_branch = self.nodes.groupby(\"global_branch_index\")[\n        \"global_comp_index\"\n    ].apply(list)\n    branch_lens = [np.sum(self.nodes[\"length\"][np.asarray(i)]) for i in inds_branch]\n    endpoints = []\n\n    # Different levels will get a different \"angle\" at which the children emerge from\n    # the parents. This angle is defined by the `y_offset_multiplier`. This value\n    # defines the range between y-location of the first and of the last child of a\n    # parent.\n    y_offset_multiplier = np.linspace(\n        max_y_multiplier, min_y_multiplier, np.max(levels) + 1\n    )\n\n    for b in range(self.total_nbranches):\n        # For networks with mixed SWC and from-scatch neurons, only update those\n        # branches that do not have coordingates yet.\n        if np.any(np.isnan(self.xyzr[b])):\n            if parents[b] > -1:\n                start_point = endpoints[parents[b]]\n                num_children_of_parent = num_children[parents[b]]\n                if num_children_of_parent > 1:\n                    y_offset = (\n                        ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5\n                    ) * y_offset_multiplier[levels[b]]\n                else:\n                    y_offset = 0.0\n            else:\n                start_point = [0, 0, 0]\n                y_offset = 0.0\n\n            len_of_path = np.sqrt(y_offset**2 + 1.0)\n\n            end_point = [\n                start_point[0] + branch_lens[b] / len_of_path * 1.0,\n                start_point[1] + branch_lens[b] / len_of_path * y_offset,\n                start_point[2],\n            ]\n            endpoints.append(end_point)\n\n            self.xyzr[b][:, :3] = np.asarray([start_point, end_point])\n        else:\n            # Dummy to keey the index `endpoints[parent[b]]` above working.\n            endpoints.append(np.zeros((2,)))\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.copy","title":"copy(reset_index=False, as_module=False)","text":"

Extract part of a module and return a copy of its View or a new module.

This can be used to call jx.integrate on part of a Module.

Parameters:

Name Type Description Default reset_index bool

if True, the indices of the new module are reset to start from 0.

False as_module bool

if True, a new module is returned instead of a View.

False

Returns:

Type Description Union[Module, View]

A part of the module or a copied view of it.

Source code in jaxley/modules/base.py
def copy(\n    self, reset_index: bool = False, as_module: bool = False\n) -> Union[Module, View]:\n    \"\"\"Extract part of a module and return a copy of its View or a new module.\n\n    This can be used to call `jx.integrate` on part of a Module.\n\n    Args:\n        reset_index: if True, the indices of the new module are reset to start from 0.\n        as_module: if True, a new module is returned instead of a View.\n\n    Returns:\n        A part of the module or a copied view of it.\"\"\"\n    view = deepcopy(self)\n    warnings.warn(\"This method is experimental, use at your own risk.\")\n    # TODO FROM #447: add reset_index, i.e. for parents, nodes, edges etc. such that they\n    # start from 0/-1 and are contiguous\n    if as_module:\n        raise NotImplementedError(\"Not yet implemented.\")\n        # initialize a new module with the same attributes\n    return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.copy_node_property_to_edges","title":"copy_node_property_to_edges(properties_to_import, pre_or_post=['pre', 'post'])","text":"

Copy a property that is in node over to edges.

By default, .edges does not contain the properties (radius, length, cm, channel properties,\u2026) of the pre- and post-synaptic compartments. This method allows to copy a property of the pre- and/or post-synaptic compartment to the edges. It is then accessible as module.edges.pre_property_name or module.edges.post_property_name.

Note that, if you modify the node property after having run copy_node_property_to_edges, it will not automatically update the value in .edges.

Note that, if this method is called on a View (e.g. net.cell(0).copy_node_property_to_edges), then it will return a View, but it will not modify the module itself.

Parameters:

Name Type Description Default properties_to_import Union[str, List[str]]

The name of the node properties that should be imported. To list all available properties, look at module.nodes.columns.

required pre_or_post Union[str, List[str]]

Whether to import only the pre-synaptic property (\u2018pre\u2019), only the post-synaptic property (\u2018post\u2019), or both ([\u2018pre\u2019, \u2018post\u2019]).

['pre', 'post']

Returns:

Type Description Module

A new module which has the property copied to the nodes.

Source code in jaxley/modules/base.py
def copy_node_property_to_edges(\n    self,\n    properties_to_import: Union[str, List[str]],\n    pre_or_post: Union[str, List[str]] = [\"pre\", \"post\"],\n) -> Module:\n    \"\"\"Copy a property that is in `node` over to `edges`.\n\n    By default, `.edges` does not contain the properties (radius, length, cm,\n    channel properties,...) of the pre- and post-synaptic compartments. This\n    method allows to copy a property of the pre- and/or post-synaptic compartment\n    to the edges. It is then accessible as `module.edges.pre_property_name` or\n    `module.edges.post_property_name`.\n\n    Note that, if you modify the node property _after_ having run\n    `copy_node_property_to_edges`, it will not automatically update the value in\n    `.edges`.\n\n    Note that, if this method is called on a View (e.g.\n    `net.cell(0).copy_node_property_to_edges`), then it will return a View, but\n    it will _not_ modify the module itself.\n\n    Args:\n        properties_to_import: The name of the node properties that should be\n            imported. To list all available properties, look at\n            `module.nodes.columns`.\n        pre_or_post: Whether to import only the pre-synaptic property ('pre'), only\n            the post-synaptic property ('post'), or both (['pre', 'post']).\n\n    Returns:\n        A new module which has the property copied to the `nodes`.\n    \"\"\"\n    # If a string is passed, wrap it as a list.\n    if isinstance(pre_or_post, str):\n        pre_or_post = [pre_or_post]\n    if isinstance(properties_to_import, str):\n        properties_to_import = [properties_to_import]\n\n    for pre_or_post_val in pre_or_post:\n        assert pre_or_post_val in [\"pre\", \"post\"]\n        for property_to_import in properties_to_import:\n            # Delete the column if it already exists. Otherwise it would exist\n            # twice.\n            if f\"{pre_or_post_val}_{property_to_import}\" in self.edges.columns:\n                self.edges.drop(\n                    columns=f\"{pre_or_post_val}_{property_to_import}\", inplace=True\n                )\n\n            self.edges = self.edges.join(\n                self.nodes[[property_to_import, \"global_comp_index\"]].set_index(\n                    \"global_comp_index\"\n                ),\n                on=f\"{pre_or_post_val}_global_comp_index\",\n            )\n            self.edges = self.edges.rename(\n                columns={\n                    property_to_import: f\"{pre_or_post_val}_{property_to_import}\"\n                }\n            )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.data_clamp","title":"data_clamp(state_name, state_array, data_clamps=None, verbose=False)","text":"

Insert a clamp into the module within jit (or grad).

Parameters:

Name Type Description Default state_name str

Name of the state variable to set.

required state_array ndarray

Time series of the state variable in the default Jaxley unit. State array should be of shape (num_clamps, simulation_time) or (simulation_time, ) for a single clamp.

required verbose bool

Whether or not to print the number of inserted clamps. False by default because this method is meant to be jitted.

False Source code in jaxley/modules/base.py
def data_clamp(\n    self,\n    state_name: str,\n    state_array: jnp.ndarray,\n    data_clamps: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n    verbose: bool = False,\n):\n    \"\"\"Insert a clamp into the module within jit (or grad).\n\n    Args:\n        state_name: Name of the state variable to set.\n        state_array: Time series of the state variable in the default Jaxley unit.\n            State array should be of shape (num_clamps, simulation_time) or\n            (simulation_time, ) for a single clamp.\n        verbose: Whether or not to print the number of inserted clamps. `False`\n            by default because this method is meant to be jitted.\n    \"\"\"\n    comp_states, edge_states = self._get_state_names()\n    if state_name not in comp_states + edge_states:\n        raise KeyError(f\"{state_name} is not a recognized state in this module.\")\n    data = self.nodes if state_name in comp_states else self.edges\n    return self._data_external_input(\n        state_name, state_array, data_clamps, data, verbose=verbose\n    )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.data_set","title":"data_set(key, val, param_state)","text":"

Set parameter of module (or its view) to a new value within jit.

Parameters:

Name Type Description Default key str

The name of the parameter to set.

required val Union[float, ndarray]

The value to set the parameter to. If it is jnp.ndarray then it must be of shape (len(num_compartments)).

required param_state Optional[List[Dict]]

State of the setted parameters, internally used such that this function does not modify global state.

required Source code in jaxley/modules/base.py
def data_set(\n    self,\n    key: str,\n    val: Union[float, jnp.ndarray],\n    param_state: Optional[List[Dict]],\n):\n    \"\"\"Set parameter of module (or its view) to a new value within `jit`.\n\n    Args:\n        key: The name of the parameter to set.\n        val: The value to set the parameter to. If it is `jnp.ndarray` then it\n            must be of shape `(len(num_compartments))`.\n        param_state: State of the setted parameters, internally used such that this\n            function does not modify global state.\n    \"\"\"\n    # Note: `data_set` does not support arrays for `val`.\n    is_node_param = key in self.nodes.columns\n    data = self.nodes if is_node_param else self.edges\n    viewed_inds = self._nodes_in_view if is_node_param else self._edges_in_view\n    if key in data.columns:\n        not_nan = ~data[key].isna()\n        added_param_state = [\n            {\n                \"indices\": np.atleast_2d(viewed_inds[not_nan]),\n                \"key\": key,\n                \"val\": jnp.atleast_1d(jnp.asarray(val)),\n            }\n        ]\n        if param_state is not None:\n            param_state += added_param_state\n        else:\n            param_state = added_param_state\n    else:\n        raise KeyError(\"Key not recognized.\")\n    return param_state\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.data_stimulate","title":"data_stimulate(current, data_stimuli=None, verbose=False)","text":"

Insert a stimulus into the module within jit (or grad).

Parameters:

Name Type Description Default current ndarray

Current in nA.

required verbose bool

Whether or not to print the number of inserted stimuli. False by default because this method is meant to be jitted.

False Source code in jaxley/modules/base.py
def data_stimulate(\n    self,\n    current: jnp.ndarray,\n    data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n    verbose: bool = False,\n) -> Tuple[jnp.ndarray, pd.DataFrame]:\n    \"\"\"Insert a stimulus into the module within jit (or grad).\n\n    Args:\n        current: Current in `nA`.\n        verbose: Whether or not to print the number of inserted stimuli. `False`\n            by default because this method is meant to be jitted.\n    \"\"\"\n    return self._data_external_input(\n        \"i\", current, data_stimuli, self.nodes, verbose=verbose\n    )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_channel","title":"delete_channel(channel)","text":"

Remove a channel from the module.

Parameters:

Name Type Description Default channel Channel

The channel to remove.

required Source code in jaxley/modules/base.py
def delete_channel(self, channel: Channel):\n    \"\"\"Remove a channel from the module.\n\n    Args:\n        channel: The channel to remove.\"\"\"\n    name = channel._name\n    channel_names = [c._name for c in self.channels]\n    all_channel_names = [c._name for c in self.base.channels]\n    if name in channel_names:\n        channel_cols = list(channel.channel_params.keys())\n        channel_cols += list(channel.channel_states.keys())\n        self.base.nodes.loc[self._nodes_in_view, channel_cols] = float(\"nan\")\n        self.base.nodes.loc[self._nodes_in_view, name] = False\n\n        # only delete cols if no other comps in the module have the same channel\n        if np.all(~self.base.nodes[name]):\n            self.base.channels.pop(all_channel_names.index(name))\n            self.base.membrane_current_names.remove(channel.current_name)\n            self.base.nodes.drop(columns=channel_cols + [name], inplace=True)\n    else:\n        raise ValueError(f\"Channel {name} not found in the module.\")\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_clamps","title":"delete_clamps(state_name=None)","text":"

Removes all clamps of the given state from the module.

Source code in jaxley/modules/base.py
def delete_clamps(self, state_name: Optional[str] = None):\n    \"\"\"Removes all clamps of the given state from the module.\"\"\"\n    all_externals = list(self.externals.keys())\n    if \"i\" in all_externals:\n        all_externals.remove(\"i\")\n    state_names = all_externals if state_name is None else [state_name]\n    for state_name in state_names:\n        if state_name in self.externals:\n            keep_inds = ~np.isin(\n                self.base.external_inds[state_name], self._nodes_in_view\n            )\n            base_exts = self.base.externals\n            base_exts_inds = self.base.external_inds\n            if np.all(~keep_inds):\n                base_exts.pop(state_name, None)\n                base_exts_inds.pop(state_name, None)\n            else:\n                base_exts[state_name] = base_exts[state_name][keep_inds]\n                base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds]\n            self._update_view()\n        else:\n            pass  # does not have to be deleted if not in externals\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_recordings","title":"delete_recordings()","text":"

Removes all recordings from the module.

Source code in jaxley/modules/base.py
def delete_recordings(self):\n    \"\"\"Removes all recordings from the module.\"\"\"\n    if isinstance(self, View):\n        base_recs = self.base.recordings\n        self.base.recordings = base_recs[\n            ~base_recs.isin(self.recordings).all(axis=1)\n        ]\n        self._update_view()\n    else:\n        self.base.recordings = pd.DataFrame().from_dict({})\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_stimuli","title":"delete_stimuli()","text":"

Removes all stimuli from the module.

Source code in jaxley/modules/base.py
def delete_stimuli(self):\n    \"\"\"Removes all stimuli from the module.\"\"\"\n    self.delete_clamps(\"i\")\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_trainables","title":"delete_trainables()","text":"

Removes all trainable parameters from the module.

Source code in jaxley/modules/base.py
def delete_trainables(self):\n    \"\"\"Removes all trainable parameters from the module.\"\"\"\n\n    if isinstance(self, View):\n        trainables_and_inds = self._filter_trainables(is_viewed=False)\n        self.base.indices_set_by_trainables = trainables_and_inds[0]\n        self.base.trainable_params = trainables_and_inds[1]\n        self.base.num_trainable_params -= self.num_trainable_params\n    else:\n        self.base.indices_set_by_trainables = []\n        self.base.trainable_params = []\n        self.base.num_trainable_params = 0\n    self._update_view()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.distance","title":"distance(endpoint)","text":"

Return the direct distance between two compartments. This does not compute the pathwise distance (which is currently not implemented). Args: endpoint: The compartment to which to compute the distance to.

Source code in jaxley/modules/base.py
def distance(self, endpoint: \"View\") -> float:\n    \"\"\"Return the direct distance between two compartments.\n    This does not compute the pathwise distance (which is currently not\n    implemented).\n    Args:\n        endpoint: The compartment to which to compute the distance to.\n    \"\"\"\n    assert len(self.xyzr) == 1 and len(endpoint.xyzr) == 1\n    start_xyz = np.mean(self.xyzr[0][:, :3], axis=0)\n    end_xyz = np.mean(endpoint.xyzr[0][:, :3], axis=0)\n    return np.sqrt(np.sum((start_xyz - end_xyz) ** 2))\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.edge","title":"edge(idx)","text":"

Return a View of the module at the selected synapse edges(s).

Parameters:

Name Type Description Default idx Any

index of the edge to view.

required

Returns:

Type Description View

View of the module at the specified edge index.

Source code in jaxley/modules/base.py
def edge(self, idx: Any) -> View:\n    \"\"\"Return a View of the module at the selected synapse edges(s).\n\n    Args:\n        idx: index of the edge to view.\n\n    Returns:\n        View of the module at the specified edge index.\"\"\"\n    return self._at_edges(\"edge\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.get_all_parameters","title":"get_all_parameters(pstate, voltage_solver)","text":"

Return all parameters (and coupling conductances) needed to simulate.

Runs _compute_axial_conductances() and return every parameter that is needed to solve the ODE. This includes conductances, radiuses, lengths, axial_resistivities, but also coupling conductances.

This is done by first obtaining the current value of every parameter (not only the trainable ones) and then replacing the trainable ones with the value in trainable_params(). This function is run within jx.integrate().

pstate can be obtained by calling params_to_pstate().

.. code-block:: python

params = module.get_parameters() # i.e. [0, 1, 2]\npstate = params_to_pstate(params, module.indices_set_by_trainables)\nmodule.to_jax() # needed for call to module.jaxnodes\n

Parameters:

Name Type Description Default pstate List[Dict]

The state of the trainable parameters. pstate takes the form [{ \u201ckey\u201d: \u201cgNa\u201d, \u201cindices\u201d: jnp.array([0, 1, 2]), \u201cval\u201d: jnp.array([0.1, 0.2, 0.3]) }, \u2026].

required voltage_solver str

The voltage solver that is used. Since jax.sparse and jaxley.xyz require different formats of the axial conductances, this function will default to different building methods.

required

Returns:

Type Description Dict[str, ndarray]

A dictionary of all module parameters.

Source code in jaxley/modules/base.py
@only_allow_module\ndef get_all_parameters(\n    self, pstate: List[Dict], voltage_solver: str\n) -> Dict[str, jnp.ndarray]:\n    # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n    \"\"\"Return all parameters (and coupling conductances) needed to simulate.\n\n    Runs `_compute_axial_conductances()` and return every parameter that is needed\n    to solve the ODE. This includes conductances, radiuses, lengths,\n    axial_resistivities, but also coupling conductances.\n\n    This is done by first obtaining the current value of every parameter (not only\n    the trainable ones) and then replacing the trainable ones with the value\n    in `trainable_params()`. This function is run within `jx.integrate()`.\n\n    pstate can be obtained by calling `params_to_pstate()`.\n\n    .. code-block:: python\n\n        params = module.get_parameters() # i.e. [0, 1, 2]\n        pstate = params_to_pstate(params, module.indices_set_by_trainables)\n        module.to_jax() # needed for call to module.jaxnodes\n\n    Args:\n        pstate: The state of the trainable parameters. pstate takes the form\n            [{\n                \"key\": \"gNa\", \"indices\": jnp.array([0, 1, 2]),\n                \"val\": jnp.array([0.1, 0.2, 0.3])\n            }, ...].\n        voltage_solver: The voltage solver that is used. Since `jax.sparse` and\n            `jaxley.xyz` require different formats of the axial conductances, this\n            function will default to different building methods.\n\n    Returns:\n        A dictionary of all module parameters.\n    \"\"\"\n    params = {}\n    for key in [\"radius\", \"length\", \"axial_resistivity\", \"capacitance\"]:\n        params[key] = self.base.jaxnodes[key]\n\n    for channel in self.base.channels:\n        for channel_params in channel.channel_params:\n            params[channel_params] = self.base.jaxnodes[channel_params]\n\n    for synapse_params in self.base.synapse_param_names:\n        params[synapse_params] = self.base.jaxedges[synapse_params]\n\n    # Override with those parameters set by `.make_trainable()`.\n    for parameter in pstate:\n        key = parameter[\"key\"]\n        inds = parameter[\"indices\"]\n        set_param = parameter[\"val\"]\n\n        # This is needed since SynapseViews worked differently before.\n        # This mimics the old behaviour and tranformes the new indices\n        # to the old indices.\n        # TODO FROM #447: Longterm this should be gotten rid of.\n        # Instead edges should work similar to nodes (would also allow for\n        # param sharing).\n        synapse_inds = self.base.edges.groupby(\"type\").rank()[\"global_edge_index\"]\n        synapse_inds = (synapse_inds.astype(int) - 1).to_numpy()\n        if key in self.base.synapse_param_names:\n            inds = synapse_inds[inds]\n\n        if key in params:  # Only parameters, not initial states.\n            # `inds` is of shape `(num_params, num_comps_per_param)`.\n            # `set_param` is of shape `(num_params,)`\n            # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n            # `.set()` to work. This is done with `[:, None]`.\n            params[key] = params[key].at[inds].set(set_param[:, None])\n\n    # Compute conductance params and add them to the params dictionary.\n    params[\"axial_conductances\"] = self.base._compute_axial_conductances(\n        params=params\n    )\n    return params\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.get_all_states","title":"get_all_states(pstate, all_params, delta_t)","text":"

Get the full initial state of the module from jaxnodes and trainables.

Parameters:

Name Type Description Default pstate List[Dict]

The state of the trainable parameters.

required all_params

All parameters of the module.

required delta_t float

The time step.

required

Returns:

Type Description Dict[str, ndarray]

A dictionary of all states of the module.

Source code in jaxley/modules/base.py
@only_allow_module\ndef get_all_states(\n    self, pstate: List[Dict], all_params, delta_t: float\n) -> Dict[str, jnp.ndarray]:\n    # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n    \"\"\"Get the full initial state of the module from jaxnodes and trainables.\n\n    Args:\n        pstate: The state of the trainable parameters.\n        all_params: All parameters of the module.\n        delta_t: The time step.\n\n    Returns:\n        A dictionary of all states of the module.\n    \"\"\"\n    states = self.base._get_states_from_nodes_and_edges()\n\n    # Override with the initial states set by `.make_trainable()`.\n    for parameter in pstate:\n        key = parameter[\"key\"]\n        inds = parameter[\"indices\"]\n        set_param = parameter[\"val\"]\n        if key in states:  # Only initial states, not parameters.\n            # `inds` is of shape `(num_params, num_comps_per_param)`.\n            # `set_param` is of shape `(num_params,)`\n            # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n            # `.set()` to work. This is done with `[:, None]`.\n            states[key] = states[key].at[inds].set(set_param[:, None])\n\n    # Add to the states the initial current through every channel.\n    states, _ = self.base._channel_currents(\n        states, delta_t, self.channels, self.nodes, all_params\n    )\n\n    # Add to the states the initial current through every synapse.\n    states, _ = self.base._synapse_currents(\n        states, self.synapses, all_params, delta_t, self.edges\n    )\n    return states\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.get_parameters","title":"get_parameters()","text":"

Get all trainable parameters.

The returned parameters should be passed to `jx.integrate(\u2026, params=params).

Returns:

Type Description List[Dict[str, ndarray]]

A list of all trainable parameters in the form of [{\u201cgNa\u201d: jnp.array([0.1, 0.2, 0.3])}, \u2026].

Source code in jaxley/modules/base.py
def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:\n    \"\"\"Get all trainable parameters.\n\n    The returned parameters should be passed to `jx.integrate(..., params=params).\n\n    Returns:\n        A list of all trainable parameters in the form of\n            [{\"gNa\": jnp.array([0.1, 0.2, 0.3])}, ...].\n    \"\"\"\n    return self.trainable_params\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.init_states","title":"init_states(delta_t=0.025)","text":"

Initialize all mechanisms in their steady state.

This considers the voltages and parameters of each compartment.

Parameters:

Name Type Description Default delta_t float

Passed on to channel.init_state().

0.025 Source code in jaxley/modules/base.py
@only_allow_module\ndef init_states(self, delta_t: float = 0.025):\n    # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n    \"\"\"Initialize all mechanisms in their steady state.\n\n    This considers the voltages and parameters of each compartment.\n\n    Args:\n        delta_t: Passed on to `channel.init_state()`.\n    \"\"\"\n    # Update states of the channels.\n    channel_nodes = self.base.nodes\n    states = self.base._get_states_from_nodes_and_edges()\n\n    # We do not use any `pstate` for initializing. In principle, we could change\n    # that by allowing an input `params` and `pstate` to this function.\n    # `voltage_solver` could also be `jax.sparse` here, because both of them\n    # build the channel parameters in the same way.\n    params = self.base.get_all_parameters([], voltage_solver=\"jaxley.thomas\")\n\n    for channel in self.base.channels:\n        name = channel._name\n        channel_indices = channel_nodes.loc[channel_nodes[name]][\n            \"global_comp_index\"\n        ].to_numpy()\n        voltages = channel_nodes.loc[channel_indices, \"v\"].to_numpy()\n\n        channel_param_names = list(channel.channel_params.keys())\n        channel_state_names = list(channel.channel_states.keys())\n        channel_states = query_channel_states_and_params(\n            states, channel_state_names, channel_indices\n        )\n        channel_params = query_channel_states_and_params(\n            params, channel_param_names, channel_indices\n        )\n\n        init_state = channel.init_state(\n            channel_states, voltages, channel_params, delta_t\n        )\n\n        # `init_state` might not return all channel states. Only the ones that are\n        # returned are updated here.\n        for key, val in init_state.items():\n            # Note that we are overriding `self.nodes` here, but `self.nodes` is\n            # not used above to actually compute the current states (so there are\n            # no issues with overriding states).\n            self.nodes.loc[channel_indices, key] = val\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.insert","title":"insert(channel)","text":"

Insert a channel into the module.

Parameters:

Name Type Description Default channel Channel

The channel to insert.

required Source code in jaxley/modules/base.py
def insert(self, channel: Channel):\n    \"\"\"Insert a channel into the module.\n\n    Args:\n        channel: The channel to insert.\"\"\"\n    name = channel._name\n\n    # Channel does not yet exist in the `jx.Module` at all.\n    if name not in [c._name for c in self.base.channels]:\n        self.base.channels.append(channel)\n        self.base.nodes[name] = (\n            False  # Previous columns do not have the new channel.\n        )\n\n    if channel.current_name not in self.base.membrane_current_names:\n        self.base.membrane_current_names.append(channel.current_name)\n\n    # Add a binary column that indicates if a channel is present.\n    self.base.nodes.loc[self._nodes_in_view, name] = True\n\n    # Loop over all new parameters, e.g. gNa, eNa.\n    for key in channel.channel_params:\n        self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_params[key]\n\n    # Loop over all new parameters, e.g. gNa, eNa.\n    for key in channel.channel_states:\n        self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key]\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.loc","title":"loc(at)","text":"

Return a View of the module at the selected branch location(s).

Parameters:

Name Type Description Default at Any

location along the branch.

required

Returns:

Type Description View

View of the module at the specified branch location.

Source code in jaxley/modules/base.py
def loc(self, at: Any) -> View:\n    \"\"\"Return a View of the module at the selected branch location(s).\n\n    Args:\n        at: location along the branch.\n\n    Returns:\n        View of the module at the specified branch location.\"\"\"\n    global_comp_idxs = []\n    for i in self._branches_in_view:\n        ncomp = self.base.ncomp_per_branch[i]\n        comp_locs = np.linspace(0, 1, ncomp)\n        at = comp_locs if is_str_all(at) else self._reformat_index(at, dtype=float)\n        comp_edges = np.linspace(0, 1 + 1e-10, ncomp + 1)\n        idx = np.digitize(at, comp_edges) - 1 + self.base.cumsum_ncomp[i]\n        global_comp_idxs.append(idx)\n    global_comp_idxs = np.concatenate(global_comp_idxs)\n    orig_scope = self._scope\n    # global scope needed to select correct comps, for i.e. branches w. ncomp=[1,2]\n    # loc(0.9)  will correspond to different local branches (0 vs 1).\n    view = self.scope(\"global\").comp(global_comp_idxs).scope(orig_scope)\n    view._current_view = \"loc\"\n    return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.make_trainable","title":"make_trainable(key, init_val=None, verbose=True)","text":"

Make a parameter trainable.

If a parameter is made trainable, it will be returned by get_parameters() and should then be passed to jx.integrate(..., params=params).

Parameters:

Name Type Description Default key str

Name of the parameter to make trainable.

required init_val Optional[Union[float, list]]

Initial value of the parameter. If float, the same value is used for every created parameter. If list, the length of the list has to match the number of created parameters. If None, the current parameter value is used and if parameter sharing is performed that the current parameter value is averaged over all shared parameters.

None verbose bool

Whether to print the number of parameters that are added and the total number of parameters.

True Source code in jaxley/modules/base.py
def make_trainable(\n    self,\n    key: str,\n    init_val: Optional[Union[float, list]] = None,\n    verbose: bool = True,\n):\n    \"\"\"Make a parameter trainable.\n\n    If a parameter is made trainable, it will be returned by `get_parameters()`\n    and should then be passed to `jx.integrate(..., params=params)`.\n\n    Args:\n        key: Name of the parameter to make trainable.\n        init_val: Initial value of the parameter. If `float`, the same value is\n            used for every created parameter. If `list`, the length of the list has\n            to match the number of created parameters. If `None`, the current\n            parameter value is used and if parameter sharing is performed that the\n            current parameter value is averaged over all shared parameters.\n        verbose: Whether to print the number of parameters that are added and the\n            total number of parameters.\n    \"\"\"\n    assert (\n        self.allow_make_trainable\n    ), \"network.cell('all').make_trainable() is not supported. Use a for-loop over cells.\"\n    ncomps_per_branch = (\n        self.base.nodes[\"global_branch_index\"].value_counts().to_numpy()\n    )\n    assert np.all(\n        ncomps_per_branch == ncomps_per_branch[0]\n    ), \"Parameter sharing is not allowed for modules containing branches with different numbers of compartments.\"\n\n    data = self.nodes if key in self.nodes.columns else None\n    data = self.edges if key in self.edges.columns else data\n\n    assert data is not None, f\"Key '{key}' not found in nodes or edges\"\n    not_nan = ~data[key].isna()\n    data = data.loc[not_nan]\n    assert (\n        len(data) > 0\n    ), \"No settable parameters found in the selected compartments.\"\n\n    grouped_view = data.groupby(\"controlled_by_param\")\n    # Because of this `x.index.values` we cannot support `make_trainable()` on\n    # the module level for synapse parameters (but only for `SynapseView`).\n    inds_of_comps = list(\n        grouped_view.apply(lambda x: x.index.values, include_groups=False)\n    )\n    indices_per_param = jnp.stack(inds_of_comps)\n    # Sorted inds are only used to infer the correct starting values.\n    param_vals = jnp.asarray(\n        [data.loc[inds, key].to_numpy() for inds in inds_of_comps]\n    )\n\n    # Set the value which the trainable parameter should take.\n    num_created_parameters = len(indices_per_param)\n    if init_val is not None:\n        if isinstance(init_val, float):\n            new_params = jnp.asarray([init_val] * num_created_parameters)\n        elif isinstance(init_val, list):\n            assert (\n                len(init_val) == num_created_parameters\n            ), f\"len(init_val)={len(init_val)}, but trying to create {num_created_parameters} parameters.\"\n            new_params = jnp.asarray(init_val)\n        else:\n            raise ValueError(\n                f\"init_val must a float, list, or None, but it is a {type(init_val).__name__}.\"\n            )\n    else:\n        new_params = jnp.mean(param_vals, axis=1)\n    self.base.trainable_params.append({key: new_params})\n    self.base.indices_set_by_trainables.append(indices_per_param)\n    self.base.num_trainable_params += num_created_parameters\n    if verbose:\n        print(\n            f\"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}\"\n        )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.move","title":"move(x=0.0, y=0.0, z=0.0, update_nodes=False)","text":"

Move cells or networks by adding to their (x, y, z) coordinates.

This function is used only for visualization. It does not affect the simulation.

Parameters:

Name Type Description Default x float

The amount to move in the x direction in um.

0.0 y float

The amount to move in the y direction in um.

0.0 z float

The amount to move in the z direction in um.

0.0 update_nodes bool

Whether .nodes should be updated or not. Setting this to False largely speeds up moving, especially for big networks, but .nodes or .show will not show the new xyz coordinates.

False Source code in jaxley/modules/base.py
def move(\n    self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = False\n):\n    \"\"\"Move cells or networks by adding to their (x, y, z) coordinates.\n\n    This function is used only for visualization. It does not affect the simulation.\n\n    Args:\n        x: The amount to move in the x direction in um.\n        y: The amount to move in the y direction in um.\n        z: The amount to move in the z direction in um.\n        update_nodes: Whether `.nodes` should be updated or not. Setting this to\n            `False` largely speeds up moving, especially for big networks, but\n            `.nodes` or `.show` will not show the new xyz coordinates.\n    \"\"\"\n    for i in self._branches_in_view:\n        self.base.xyzr[i][:, :3] += np.array([x, y, z])\n    if update_nodes:\n        self.compute_compartment_centers()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.move_to","title":"move_to(x=0.0, y=0.0, z=0.0, update_nodes=False)","text":"

Move cells or networks to a location (x, y, z).

If x, y, and z are floats, then the first compartment of the first branch of the first cell is moved to that float coordinate, and everything else is shifted by the difference between that compartment\u2019s previous coordinate and the new float location.

If x, y, and z are arrays, then they must each have a length equal to the number of cells being moved. Then the first compartment of the first branch of each cell is moved to the specified location.

Parameters:

Name Type Description Default update_nodes bool

Whether .nodes should be updated or not. Setting this to False largely speeds up moving, especially for big networks, but .nodes or .show will not show the new xyz coordinates.

False Source code in jaxley/modules/base.py
def move_to(\n    self,\n    x: Union[float, np.ndarray] = 0.0,\n    y: Union[float, np.ndarray] = 0.0,\n    z: Union[float, np.ndarray] = 0.0,\n    update_nodes: bool = False,\n):\n    \"\"\"Move cells or networks to a location (x, y, z).\n\n    If x, y, and z are floats, then the first compartment of the first branch\n    of the first cell is moved to that float coordinate, and everything else is\n    shifted by the difference between that compartment's previous coordinate and\n    the new float location.\n\n    If x, y, and z are arrays, then they must each have a length equal to the number\n    of cells being moved. Then the first compartment of the first branch of each\n    cell is moved to the specified location.\n\n    Args:\n        update_nodes: Whether `.nodes` should be updated or not. Setting this to\n            `False` largely speeds up moving, especially for big networks, but\n            `.nodes` or `.show` will not show the new xyz coordinates.\n    \"\"\"\n    # Test if any coordinate values are NaN which would greatly affect moving\n    if np.any(np.concatenate(self.xyzr, axis=0)[:, :3] == np.nan):\n        raise ValueError(\n            \"NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values.\"\n        )\n\n    # can only iterate over cells for networks\n    # lambda makes sure that generator can be created multiple times\n    base_is_net = self.base._current_view == \"network\"\n    cells = lambda: (self.cells if base_is_net else [self])\n\n    root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in cells()])\n    root_xyz = root_xyz_cells[0] if isinstance(x, float) else root_xyz_cells\n    move_by = np.array([x, y, z]).T - root_xyz\n\n    if len(move_by.shape) == 1:\n        move_by = np.tile(move_by, (len(self._cells_in_view), 1))\n\n    for cell, offset in zip(cells(), move_by):\n        for idx in cell._branches_in_view:\n            self.base.xyzr[idx][:, :3] += offset\n    if update_nodes:\n        self.compute_compartment_centers()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.rotate","title":"rotate(degrees, rotation_axis='xy', update_nodes=False)","text":"

Rotate jaxley modules clockwise. Used only for visualization.

This function is used only for visualization. It does not affect the simulation.

Parameters:

Name Type Description Default degrees float

How many degrees to rotate the module by.

required rotation_axis str

Either of {xy | xz | yz}.

'xy' Source code in jaxley/modules/base.py
def rotate(\n    self, degrees: float, rotation_axis: str = \"xy\", update_nodes: bool = False\n):\n    \"\"\"Rotate jaxley modules clockwise. Used only for visualization.\n\n    This function is used only for visualization. It does not affect the simulation.\n\n    Args:\n        degrees: How many degrees to rotate the module by.\n        rotation_axis: Either of {`xy` | `xz` | `yz`}.\n    \"\"\"\n    degrees = degrees / 180 * np.pi\n    if rotation_axis == \"xy\":\n        dims = [0, 1]\n    elif rotation_axis == \"xz\":\n        dims = [0, 2]\n    elif rotation_axis == \"yz\":\n        dims = [1, 2]\n    else:\n        raise ValueError\n\n    rotation_matrix = np.asarray(\n        [[np.cos(degrees), np.sin(degrees)], [-np.sin(degrees), np.cos(degrees)]]\n    )\n    for i in self._branches_in_view:\n        rot = np.dot(rotation_matrix, self.base.xyzr[i][:, dims].T).T\n        self.base.xyzr[i][:, dims] = rot\n    if update_nodes:\n        self.compute_compartment_centers()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.scope","title":"scope(scope)","text":"

Return a View of the module with the specified scope.

For example cell.scope(\"global\").branch(2).scope(\"local\").comp(1) will return the 1st compartment of branch 2.

Parameters:

Name Type Description Default scope str

either \u201cglobal\u201d or \u201clocal\u201d.

required

Returns:

Type Description View

View with the specified scope.

Source code in jaxley/modules/base.py
def scope(self, scope: str) -> View:\n    \"\"\"Return a View of the module with the specified scope.\n\n    For example `cell.scope(\"global\").branch(2).scope(\"local\").comp(1)`\n    will return the 1st compartment of branch 2.\n\n    Args:\n        scope: either \"global\" or \"local\".\n\n    Returns:\n        View with the specified scope.\"\"\"\n    view = self.view\n    view.set_scope(scope)\n    return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.select","title":"select(nodes=None, edges=None, sorted=False)","text":"

Return View of the module filtered by specific node or edges indices.

Parameters:

Name Type Description Default nodes ndarray

indices of nodes to view. If None, all nodes are viewed.

None edges ndarray

indices of edges to view. If None, all edges are viewed.

None sorted bool

if True, nodes and edges are sorted.

False

Returns:

Type Description View

View for subset of selected nodes and/or edges.

Source code in jaxley/modules/base.py
def select(\n    self, nodes: np.ndarray = None, edges: np.ndarray = None, sorted: bool = False\n) -> View:\n    \"\"\"Return View of the module filtered by specific node or edges indices.\n\n    Args:\n        nodes: indices of nodes to view. If None, all nodes are viewed.\n        edges: indices of edges to view. If None, all edges are viewed.\n        sorted: if True, nodes and edges are sorted.\n\n    Returns:\n        View for subset of selected nodes and/or edges.\"\"\"\n\n    nodes = self._reformat_index(nodes) if nodes is not None else None\n    nodes = self._nodes_in_view if is_str_all(nodes) else nodes\n    nodes = np.sort(nodes) if sorted else nodes\n\n    edges = self._reformat_index(edges) if edges is not None else None\n    edges = self._edges_in_view if is_str_all(edges) else edges\n    edges = np.sort(edges) if sorted else edges\n\n    view = View(self, nodes, edges)\n    view._set_controlled_by_param(\"filter\")\n    return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.set","title":"set(key, val)","text":"

Set parameter of module (or its view) to a new value.

Note that this function can not be called within jax.jit or jax.grad. Instead, it should be used set the parameters of the module before the simulation. Use .data_set() to set parameters during jax.jit or jax.grad.

Parameters:

Name Type Description Default key str

The name of the parameter to set.

required val Union[float, ndarray]

The value to set the parameter to. If it is jnp.ndarray then it must be of shape (len(num_compartments)).

required Source code in jaxley/modules/base.py
def set(self, key: str, val: Union[float, jnp.ndarray]):\n    \"\"\"Set parameter of module (or its view) to a new value.\n\n    Note that this function can not be called within `jax.jit` or `jax.grad`.\n    Instead, it should be used set the parameters of the module **before** the\n    simulation. Use `.data_set()` to set parameters during `jax.jit` or\n    `jax.grad`.\n\n    Args:\n        key: The name of the parameter to set.\n        val: The value to set the parameter to. If it is `jnp.ndarray` then it\n            must be of shape `(len(num_compartments))`.\n    \"\"\"\n    if key in self.nodes.columns:\n        not_nan = ~self.nodes[key].isna().to_numpy()\n        self.base.nodes.loc[self._nodes_in_view[not_nan], key] = val\n    elif key in self.edges.columns:\n        not_nan = ~self.edges[key].isna().to_numpy()\n        self.base.edges.loc[self._edges_in_view[not_nan], key] = val\n    else:\n        raise KeyError(f\"Key '{key}' not found in nodes or edges\")\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.set_ncomp","title":"set_ncomp(ncomp, min_radius=None)","text":"

Set the number of compartments with which the branch is discretized.

Parameters:

Name Type Description Default ncomp int

The number of compartments that the branch should be discretized into.

required min_radius Optional[float]

Only used if the morphology was read from an SWC file. If passed the radius is capped to be at least this value.

None Source code in jaxley/modules/base.py
def set_ncomp(\n    self,\n    ncomp: int,\n    min_radius: Optional[float] = None,\n):\n    \"\"\"Set the number of compartments with which the branch is discretized.\n\n    Args:\n        ncomp: The number of compartments that the branch should be discretized\n            into.\n        min_radius: Only used if the morphology was read from an SWC file. If passed\n            the radius is capped to be at least this value.\n\n    Raises:\n        - When there are stimuli in any compartment in the module.\n        - When there are recordings in any compartment in the module.\n        - When the channels of the compartments are not the same within the branch\n        that is modified.\n        - When the lengths of the compartments are not the same within the branch\n        that is modified.\n        - Unless the morphology was read from an SWC file, when the radiuses of the\n        compartments are not the same within the branch that is modified.\n    \"\"\"\n    assert len(self.base.externals) == 0, \"No stimuli allowed!\"\n    assert len(self.base.recordings) == 0, \"No recordings allowed!\"\n    assert len(self.base.trainable_params) == 0, \"No trainables allowed!\"\n\n    assert self.base._module_type != \"network\", \"This is not allowed for networks.\"\n    assert not (\n        self.base._module_type == \"cell\"\n        and len(self._branches_in_view) == len(self.base._branches_in_view)\n    ), \"This is not allowed for cells.\"\n\n    # Update all attributes that are affected by compartment structure.\n    view = self.nodes.copy()\n    all_nodes = self.base.nodes\n    start_idx = self.nodes[\"global_comp_index\"].to_numpy()[0]\n    ncomp_per_branch = self.base.ncomp_per_branch\n    channel_names = [c._name for c in self.base.channels]\n    channel_param_names = list(\n        chain(*[c.channel_params for c in self.base.channels])\n    )\n    channel_state_names = list(\n        chain(*[c.channel_states for c in self.base.channels])\n    )\n    radius_generating_fns = self.base._radius_generating_fns\n\n    within_branch_radiuses = view[\"radius\"].to_numpy()\n    compartment_lengths = view[\"length\"].to_numpy()\n    num_previous_ncomp = len(within_branch_radiuses)\n    branch_indices = pd.unique(view[\"global_branch_index\"])\n\n    error_msg = lambda name: (\n        f\"You previously modified the {name} of individual compartments, but \"\n        f\"now you are modifying the number of compartments in this branch. \"\n        f\"This is not allowed. First build the morphology with `set_ncomp()` and \"\n        f\"then modify the radiuses and lengths of compartments.\"\n    )\n\n    if (\n        ~np.all(within_branch_radiuses == within_branch_radiuses[0])\n        and radius_generating_fns is None\n    ):\n        raise ValueError(error_msg(\"radius\"))\n\n    for property_name in [\"length\", \"capacitance\", \"axial_resistivity\"]:\n        compartment_properties = view[property_name].to_numpy()\n        if ~np.all(compartment_properties == compartment_properties[0]):\n            raise ValueError(error_msg(property_name))\n\n    if not (self.nodes[channel_names].var() == 0.0).all():\n        raise ValueError(\n            \"Some channel exists only in some compartments of the branch which you\"\n            \"are trying to modify. This is not allowed. First specify the number\"\n            \"of compartments with `.set_ncomp()` and then insert the channels\"\n            \"accordingly.\"\n        )\n\n    if not (\n        self.nodes[channel_param_names + channel_state_names].var() == 0.0\n    ).all():\n        raise ValueError(\n            \"Some channel has different parameters or states between the \"\n            \"different compartments of the branch which you are trying to modify. \"\n            \"This is not allowed. First specify the number of compartments with \"\n            \"`.set_ncomp()` and then insert the channels accordingly.\"\n        )\n\n    # Add new rows as the average of all rows. Special case for the length is below.\n    average_row = self.nodes.mean(skipna=False)\n    average_row = average_row.to_frame().T\n    view = pd.concat([*[average_row] * ncomp], axis=\"rows\")\n\n    # Set the correct datatype after having performed an average which cast\n    # everything to float.\n    integer_cols = [\"global_cell_index\", \"global_branch_index\", \"global_comp_index\"]\n    view[integer_cols] = view[integer_cols].astype(int)\n\n    # Whether or not a channel exists in a compartment is a boolean.\n    boolean_cols = channel_names\n    view[boolean_cols] = view[boolean_cols].astype(bool)\n\n    # Special treatment for the lengths and radiuses. These are not being set as\n    # the average because we:\n    # 1) Want to maintain the total length of a branch.\n    # 2) Want to use the SWC inferred radius.\n    #\n    # Compute new compartment lengths.\n    comp_lengths = np.sum(compartment_lengths) / ncomp\n    view[\"length\"] = comp_lengths\n\n    # Compute new compartment radiuses.\n    if radius_generating_fns is not None:\n        view[\"radius\"] = build_radiuses_from_xyzr(\n            radius_fns=radius_generating_fns,\n            branch_indices=branch_indices,\n            min_radius=min_radius,\n            ncomp=ncomp,\n        )\n    else:\n        view[\"radius\"] = within_branch_radiuses[0] * np.ones(ncomp)\n\n    # Update `.nodes`.\n    # 1) Delete N rows starting from start_idx\n    number_deleted = num_previous_ncomp\n    all_nodes = all_nodes.drop(index=range(start_idx, start_idx + number_deleted))\n\n    # 2) Insert M new rows at the same location\n    df1 = all_nodes.iloc[:start_idx]  # Rows before the insertion point\n    df2 = all_nodes.iloc[start_idx:]  # Rows after the insertion point\n\n    # 3) Combine the parts: before, new rows, and after\n    all_nodes = pd.concat([df1, view, df2]).reset_index(drop=True)\n\n    # Override `comp_index` to just be a consecutive list.\n    all_nodes[\"global_comp_index\"] = np.arange(len(all_nodes))\n\n    # Update compartment structure arguments.\n    ncomp_per_branch[branch_indices] = ncomp\n    ncomp = int(np.max(ncomp_per_branch))\n    cumsum_ncomp = cumsum_leading_zero(ncomp_per_branch)\n    internal_node_inds = np.arange(cumsum_ncomp[-1])\n\n    self.base.nodes = all_nodes\n    self.base.ncomp_per_branch = ncomp_per_branch\n    self.base.ncomp = ncomp\n    self.base.cumsum_ncomp = cumsum_ncomp\n    self.base._internal_node_inds = internal_node_inds\n\n    # Update the morphology indexing (e.g., `.comp_edges`).\n    self.base._initialize()\n    self.base._init_view()\n    self.base._update_local_indices()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.set_scope","title":"set_scope(scope)","text":"

Toggle between \u201cglobal\u201d or \u201clocal\u201d scope.

Determines if global or local indices are used for viewing the module.

Parameters:

Name Type Description Default scope str

either \u201cglobal\u201d or \u201clocal\u201d.

required Source code in jaxley/modules/base.py
def set_scope(self, scope: str):\n    \"\"\"Toggle between \"global\" or \"local\" scope.\n\n    Determines if global or local indices are used for viewing the module.\n\n    Args:\n        scope: either \"global\" or \"local\".\"\"\"\n    assert scope in [\"global\", \"local\"], \"Invalid scope.\"\n    self._scope = scope\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.show","title":"show(param_names=None, *, indices=True, params=True, states=True, channel_names=None)","text":"

Print detailed information about the Module or a view of it.

Parameters:

Name Type Description Default param_names Optional[Union[str, List[str]]]

The names of the parameters to show. If None, all parameters are shown.

None indices bool

Whether to show the indices of the compartments.

True params bool

Whether to show the parameters of the compartments.

True states bool

Whether to show the states of the compartments.

True channel_names Optional[List[str]]

The names of the channels to show. If None, all channels are shown.

None

Returns:

Type Description DataFrame

A pd.DataFrame with the requested information.

Source code in jaxley/modules/base.py
def show(\n    self,\n    param_names: Optional[Union[str, List[str]]] = None,\n    *,\n    indices: bool = True,\n    params: bool = True,\n    states: bool = True,\n    channel_names: Optional[List[str]] = None,\n) -> pd.DataFrame:\n    \"\"\"Print detailed information about the Module or a view of it.\n\n    Args:\n        param_names: The names of the parameters to show. If `None`, all parameters\n            are shown.\n        indices: Whether to show the indices of the compartments.\n        params: Whether to show the parameters of the compartments.\n        states: Whether to show the states of the compartments.\n        channel_names: The names of the channels to show. If `None`, all channels are\n            shown.\n\n    Returns:\n        A `pd.DataFrame` with the requested information.\n    \"\"\"\n    nodes = self.nodes.copy()  # prevents this from being edited\n\n    cols = []\n    inds = [\"comp_index\", \"branch_index\", \"cell_index\"]\n    scopes = [\"local\", \"global\"]\n    inds = [f\"{s}_{i}\" for i in inds for s in scopes] if indices else []\n    cols += inds\n    cols += [ch._name for ch in self.channels] if channel_names else []\n    cols += (\n        sum([list(ch.channel_params) for ch in self.channels], []) if params else []\n    )\n    cols += (\n        sum([list(ch.channel_states) for ch in self.channels], []) if states else []\n    )\n\n    if not param_names is None:\n        cols = (\n            inds + [c for c in cols if c in param_names]\n            if params\n            else list(param_names)\n        )\n\n    return nodes[cols]\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.step","title":"step(u, delta_t, external_inds, externals, params, solver='bwd_euler', voltage_solver='jaxley.stone')","text":"

One step of solving the Ordinary Differential Equation.

This function is called inside of integrate and increments the state of the module by one time step. Calls _step_channels and _step_synapse to update the states of the channels and synapses using fwd_euler.

Parameters:

Name Type Description Default u Dict[str, ndarray]

The state of the module. voltages = u[\u201cv\u201d]

required delta_t float

The time step.

required external_inds Dict[str, ndarray]

The indices of the external inputs.

required externals Dict[str, ndarray]

The external inputs.

required params Dict[str, ndarray]

The parameters of the module.

required solver str

The solver to use for the voltages. Either of [\u201cbwd_euler\u201d, \u201cfwd_euler\u201d, \u201ccrank_nicolson\u201d].

'bwd_euler' voltage_solver str

The tridiagonal solver used to diagonalize the coefficient matrix of the ODE system. Either of [\u201cjaxley.thomas\u201d, \u201cjaxley.stone\u201d].

'jaxley.stone'

Returns:

Type Description Dict[str, ndarray]

The updated state of the module.

Source code in jaxley/modules/base.py
@only_allow_module\ndef step(\n    self,\n    u: Dict[str, jnp.ndarray],\n    delta_t: float,\n    external_inds: Dict[str, jnp.ndarray],\n    externals: Dict[str, jnp.ndarray],\n    params: Dict[str, jnp.ndarray],\n    solver: str = \"bwd_euler\",\n    voltage_solver: str = \"jaxley.stone\",\n) -> Dict[str, jnp.ndarray]:\n    \"\"\"One step of solving the Ordinary Differential Equation.\n\n    This function is called inside of `integrate` and increments the state of the\n    module by one time step. Calls `_step_channels` and `_step_synapse` to update\n    the states of the channels and synapses using fwd_euler.\n\n    Args:\n        u: The state of the module. voltages = u[\"v\"]\n        delta_t: The time step.\n        external_inds: The indices of the external inputs.\n        externals: The external inputs.\n        params: The parameters of the module.\n        solver: The solver to use for the voltages. Either of [\"bwd_euler\",\n            \"fwd_euler\", \"crank_nicolson\"].\n        voltage_solver: The tridiagonal solver used to diagonalize the\n            coefficient matrix of the ODE system. Either of [\"jaxley.thomas\",\n            \"jaxley.stone\"].\n\n    Returns:\n        The updated state of the module.\n    \"\"\"\n\n    # Extract the voltages\n    voltages = u[\"v\"]\n\n    # Extract the external inputs\n    if \"i\" in externals.keys():\n        i_current = externals[\"i\"]\n        i_inds = external_inds[\"i\"]\n        i_ext = self._get_external_input(\n            voltages, i_inds, i_current, params[\"radius\"], params[\"length\"]\n        )\n    else:\n        i_ext = 0.0\n\n    # Step of the channels.\n    u, (v_terms, const_terms) = self._step_channels(\n        u, delta_t, self.channels, self.nodes, params\n    )\n\n    # Step of the synapse.\n    u, (syn_v_terms, syn_const_terms) = self._step_synapse(\n        u,\n        self.synapses,\n        params,\n        delta_t,\n        self.edges,\n    )\n\n    # Clamp for channels and synapses.\n    for key in externals.keys():\n        if key not in [\"i\", \"v\"]:\n            u[key] = u[key].at[external_inds[key]].set(externals[key])\n\n    # Voltage steps.\n    cm = params[\"capacitance\"]  # Abbreviation.\n\n    # Arguments used by all solvers.\n    solver_kwargs = {\n        \"voltages\": voltages,\n        \"voltage_terms\": (v_terms + syn_v_terms) / cm,\n        \"constant_terms\": (const_terms + i_ext + syn_const_terms) / cm,\n        \"axial_conductances\": params[\"axial_conductances\"],\n        \"internal_node_inds\": self._internal_node_inds,\n    }\n\n    # Add solver specific arguments.\n    if voltage_solver == \"jax.sparse\":\n        solver_kwargs.update(\n            {\n                \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n                \"data_inds\": self._data_inds,\n                \"indices\": self._indices_jax_spsolve,\n                \"indptr\": self._indptr_jax_spsolve,\n                \"n_nodes\": self._n_nodes,\n            }\n        )\n        # Only for `bwd_euler` and `cranck-nicolson`.\n        step_voltage_implicit = step_voltage_implicit_with_jax_spsolve\n    else:\n        # Our custom sparse solver requires a different format of all conductance\n        # values to perform triangulation and backsubstution optimally.\n        #\n        # Currently, the forward Euler solver also uses this format. However,\n        # this is only for historical reasons and we are planning to change this in\n        # the future.\n        solver_kwargs.update(\n            {\n                \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n                \"sources\": np.asarray(self._comp_edges[\"source\"].to_list()),\n                \"types\": np.asarray(self._comp_edges[\"type\"].to_list()),\n                \"ncomp_per_branch\": self.ncomp_per_branch,\n                \"par_inds\": self._par_inds,\n                \"child_inds\": self._child_inds,\n                \"nbranches\": self.total_nbranches,\n                \"solver\": voltage_solver,\n                \"idx\": self._solve_indexer,\n                \"debug_states\": self.debug_states,\n            }\n        )\n        # Only for `bwd_euler` and `cranck-nicolson`.\n        step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve\n\n    if solver == \"bwd_euler\":\n        u[\"v\"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t)\n    elif solver == \"crank_nicolson\":\n        # Crank-Nicolson advances by half a step of backward and half a step of\n        # forward Euler.\n        half_step_delta_t = delta_t / 2\n        half_step_voltages = step_voltage_implicit(\n            **solver_kwargs, delta_t=half_step_delta_t\n        )\n        # The forward Euler step in Crank-Nicolson can be performed easily as\n        # `V_{n+1} = 2 * V_{n+1/2} - V_n`. See also NEURON book Chapter 4.\n        u[\"v\"] = 2 * half_step_voltages - voltages\n    elif solver == \"fwd_euler\":\n        u[\"v\"] = step_voltage_explicit(**solver_kwargs, delta_t=delta_t)\n    else:\n        raise ValueError(\n            f\"You specified `solver={solver}`. The only allowed solvers are \"\n            \"['bwd_euler', 'fwd_euler', 'crank_nicolson'].\"\n        )\n\n    # Clamp for voltages.\n    if \"v\" in externals.keys():\n        u[\"v\"] = u[\"v\"].at[external_inds[\"v\"]].set(externals[\"v\"])\n\n    return u\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.stimulate","title":"stimulate(current=None, verbose=True)","text":"

Insert a stimulus into the compartment.

current must be a 1d array or have batch dimension of size (num_compartments, ) or (1, ). If 1d, the same stimulus is added to all compartments.

This function cannot be run during jax.jit and jax.grad. Because of this, it should only be used for static stimuli (i.e., stimuli that do not depend on the data and that should not be learned). For stimuli that depend on data (or that should be learned), please use data_stimulate().

Parameters:

Name Type Description Default current Optional[ndarray]

Current in nA.

None Source code in jaxley/modules/base.py
def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True):\n    \"\"\"Insert a stimulus into the compartment.\n\n    current must be a 1d array or have batch dimension of size `(num_compartments, )`\n    or `(1, )`. If 1d, the same stimulus is added to all compartments.\n\n    This function cannot be run during `jax.jit` and `jax.grad`. Because of this,\n    it should only be used for static stimuli (i.e., stimuli that do not depend\n    on the data and that should not be learned). For stimuli that depend on data\n    (or that should be learned), please use `data_stimulate()`.\n\n    Args:\n        current: Current in `nA`.\n    \"\"\"\n    self._external_input(\"i\", current, verbose=verbose)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.to_jax","title":"to_jax()","text":"

Move .nodes to .jaxnodes.

Before the actual simulation is run (via jx.integrate), all parameters of the jx.Module are stored in .nodes (a pd.DataFrame). However, for simulation, these parameters have to be moved to be jnp.ndarrays such that they can be processed on GPU/TPU and such that the simulation can be differentiated. .to_jax() copies the .nodes to .jaxnodes.

Source code in jaxley/modules/base.py
@only_allow_module\ndef to_jax(self):\n    # TODO FROM #447: Make this work for View?\n    \"\"\"Move `.nodes` to `.jaxnodes`.\n\n    Before the actual simulation is run (via `jx.integrate`), all parameters of\n    the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for\n    simulation, these parameters have to be moved to be `jnp.ndarrays` such that\n    they can be processed on GPU/TPU and such that the simulation can be\n    differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`.\n    \"\"\"\n    self.base.jaxnodes = {}\n    for key, value in self.base.nodes.to_dict(orient=\"list\").items():\n        inds = jnp.arange(len(value))\n        self.base.jaxnodes[key] = jnp.asarray(value)[inds]\n\n    # `jaxedges` contains only parameters (no indices).\n    # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n    # we allow parameter sharing.\n    self.base.jaxedges = {}\n    edges = self.base.edges.to_dict(orient=\"list\")\n    for i, synapse in enumerate(self.base.synapses):\n        condition = np.asarray(edges[\"type_ind\"]) == i\n        for key in synapse.synapse_params:\n            self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n        for key in synapse.synapse_states:\n            self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.vis","title":"vis(ax=None, color='k', dims=(0, 1), type='line', **kwargs)","text":"

Visualize the module.

Modules can be visualized on one of the cardinal planes (xy, xz, yz) or even in 3D.

Several options are available: - line: All points from the traced morphology (xyzr), are connected with a line plot. - scatter: All traced points, are plotted as scatter points. - comp: Plots the compartmentalized morphology, including radius and shape. (shows the true compartment lengths per default, but this can be changed via the kwargs, for details see jaxley.utils.plot_utils.plot_comps). - morph: Reconstructs the 3D shape of the traced morphology. For details see jaxley.utils.plot_utils.plot_morph. Warning: For 3D plots and morphologies with many traced points this can be very slow.

Parameters:

Name Type Description Default ax Optional[Axes]

An axis into which to plot.

None color str

The color for all branches.

'k' dims Tuple[int]

Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.

(0, 1) type str

The type of plot. One of [\u201cline\u201d, \u201cscatter\u201d, \u201ccomp\u201d, \u201cmorph\u201d].

'line' kwargs

Keyword arguments passed to the plotting function.

{} Source code in jaxley/modules/base.py
def vis(\n    self,\n    ax: Optional[Axes] = None,\n    color: str = \"k\",\n    dims: Tuple[int] = (0, 1),\n    type: str = \"line\",\n    **kwargs,\n) -> Axes:\n    \"\"\"Visualize the module.\n\n    Modules can be visualized on one of the cardinal planes (xy, xz, yz) or\n    even in 3D.\n\n    Several options are available:\n    - `line`: All points from the traced morphology (`xyzr`), are connected\n    with a line plot.\n    - `scatter`: All traced points, are plotted as scatter points.\n    - `comp`: Plots the compartmentalized morphology, including radius\n    and shape. (shows the true compartment lengths per default, but this can\n    be changed via the `kwargs`, for details see\n    `jaxley.utils.plot_utils.plot_comps`).\n    - `morph`: Reconstructs the 3D shape of the traced morphology. For details see\n    `jaxley.utils.plot_utils.plot_morph`. Warning: For 3D plots and morphologies\n    with many traced points this can be very slow.\n\n    Args:\n        ax: An axis into which to plot.\n        color: The color for all branches.\n        dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n            two of them.\n        type: The type of plot. One of [\"line\", \"scatter\", \"comp\", \"morph\"].\n        kwargs: Keyword arguments passed to the plotting function.\n    \"\"\"\n    res = 100 if \"resolution\" not in kwargs else kwargs.pop(\"resolution\")\n    if \"comp\" in type.lower():\n        return plot_comps(\n            self, dims=dims, ax=ax, color=color, resolution=res, **kwargs\n        )\n    if \"morph\" in type.lower():\n        return plot_morph(\n            self, dims=dims, ax=ax, color=color, resolution=res, **kwargs\n        )\n\n    assert not np.any(\n        [np.isnan(xyzr[:, dims]).all() for xyzr in self.xyzr]\n    ), \"No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`.\"\n\n    ax = plot_graph(\n        self.xyzr,\n        dims=dims,\n        color=color,\n        ax=ax,\n        type=type,\n        **kwargs,\n    )\n\n    return ax\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.write_trainables","title":"write_trainables(trainable_params)","text":"

Write the trainables into .nodes and .edges.

This allows to, e.g., visualize trained networks with .vis().

Parameters:

Name Type Description Default trainable_params List[Dict[str, ndarray]]

The trainable parameters returned by get_parameters().

required Source code in jaxley/modules/base.py
def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):\n    \"\"\"Write the trainables into `.nodes` and `.edges`.\n\n    This allows to, e.g., visualize trained networks with `.vis()`.\n\n    Args:\n        trainable_params: The trainable parameters returned by `get_parameters()`.\n    \"\"\"\n    # We do not support views. Why? `jaxedges` does not have any NaN\n    # elements, whereas edges does. Because of this, we already need special\n    # treatment to make this function work, and it would be an even bigger hassle\n    # if we wanted to support this.\n    assert self.__class__.__name__ in [\n        \"Compartment\",\n        \"Branch\",\n        \"Cell\",\n        \"Network\",\n    ], \"Only supports modules.\"\n\n    # We could also implement this without casting the module to jax.\n    # However, I think it allows us to reuse as much code as possible and it avoids\n    # any kind of issues with indexing or parameter sharing (as this is fully\n    # taken care of by `get_all_parameters()`).\n    self.base.to_jax()\n    pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables)\n    all_params = self.base.get_all_parameters(pstate, voltage_solver=\"jaxley.stone\")\n\n    # The value for `delta_t` does not matter here because it is only used to\n    # compute the initial current. However, the initial current cannot be made\n    # trainable and so its value never gets used below.\n    all_states = self.base.get_all_states(pstate, all_params, delta_t=0.025)\n\n    # Loop only over the keys in `pstate` to avoid unnecessary computation.\n    for parameter in pstate:\n        key = parameter[\"key\"]\n        if key in self.base.nodes.columns:\n            vals_to_set = all_params if key in all_params.keys() else all_states\n            self.base.nodes[key] = vals_to_set[key]\n\n    # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n    # we allow parameter sharing.\n    edges = self.base.edges.to_dict(orient=\"list\")\n    for i, synapse in enumerate(self.base.synapses):\n        condition = np.asarray(edges[\"type_ind\"]) == i\n        for key in list(synapse.synapse_params.keys()):\n            self.base.edges.loc[condition, key] = all_params[key]\n        for key in list(synapse.synapse_states.keys()):\n            self.base.edges.loc[condition, key] = all_states[key]\n
"},{"location":"reference/modules/#compartment","title":"Compartment","text":"

Bases: Module

Compartment class.

This class defines a single compartment that can be simulated by itself or connected up into branches. It is the basic building block of a neuron model.

Source code in jaxley/modules/compartment.py
class Compartment(Module):\n    \"\"\"Compartment class.\n\n    This class defines a single compartment that can be simulated by itself or\n    connected up into branches. It is the basic building block of a neuron model.\n    \"\"\"\n\n    compartment_params: Dict = {\n        \"length\": 10.0,  # um\n        \"radius\": 1.0,  # um\n        \"axial_resistivity\": 5_000.0,  # ohm cm\n        \"capacitance\": 1.0,  # uF/cm^2\n    }\n    compartment_states: Dict = {\"v\": -70.0}\n\n    def __init__(self):\n        super().__init__()\n\n        self.ncomp = 1\n        self.ncomp_per_branch = np.asarray([1])\n        self.total_nbranches = 1\n        self.nbranches_per_cell = [1]\n        self._cumsum_nbranches = np.asarray([0, 1])\n        self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n\n        # Setting up the `nodes` for indexing.\n        self.nodes = pd.DataFrame(\n            dict(global_cell_index=[0], global_branch_index=[0], global_comp_index=[0])\n        )\n        self._append_params_and_states(self.compartment_params, self.compartment_states)\n        self._update_local_indices()\n        self._init_view()\n\n        # Synapses.\n        self.branch_edges = pd.DataFrame(\n            dict(parent_branch_index=[], child_branch_index=[])\n        )\n\n        # For morphology indexing.\n        self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n            compute_children_and_parents(self.branch_edges)\n        )\n        self._internal_node_inds = jnp.asarray([0])\n\n        # Initialize the module.\n        self._initialize()\n\n        # Coordinates.\n        self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n\n    def _init_morph_jaxley_spsolve(self):\n        self._solve_indexer = JaxleySolveIndexer(\n            cumsum_ncomp=self.cumsum_ncomp,\n            branchpoint_group_inds=np.asarray([]).astype(int),\n            children_in_level=[],\n            parents_in_level=[],\n            root_inds=np.asarray([0]),\n            remapped_node_indices=self._internal_node_inds,\n        )\n\n    def _init_morph_jax_spsolve(self):\n        \"\"\"Initialize morphology for the jax sparse voltage solver.\n\n        Explanation of `self._comp_eges['type']`:\n        `type == 0`: compartment <--> compartment (within branch)\n        `type == 1`: branchpoint --> parent-compartment\n        `type == 2`: branchpoint --> child-compartment\n        `type == 3`: parent-compartment --> branchpoint\n        `type == 4`: child-compartment --> branchpoint\n        \"\"\"\n        self._comp_edges = pd.DataFrame().from_dict(\n            {\"source\": [], \"sink\": [], \"type\": []}\n        )\n        n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n        self._n_nodes = n_nodes\n        self._data_inds = data_inds\n        self._indices_jax_spsolve = indices\n        self._indptr_jax_spsolve = indptr\n
"},{"location":"reference/modules/#branch","title":"Branch","text":"

Bases: Module

Branch class.

This class defines a single branch that can be simulated by itself or connected to build a cell. A branch is linear segment of several compartments and can be connected to no, one or more other branches at each end to build more intricate cell morphologies.

Source code in jaxley/modules/branch.py
class Branch(Module):\n    \"\"\"Branch class.\n\n    This class defines a single branch that can be simulated by itself or\n    connected to build a cell. A branch is linear segment of several compartments\n    and can be connected to no, one or more other branches at each end to build more\n    intricate cell morphologies.\n    \"\"\"\n\n    branch_params: Dict = {}\n    branch_states: Dict = {}\n\n    @deprecated_kwargs(\"0.6.0\", [\"nseg\"])\n    def __init__(\n        self,\n        compartments: Optional[Union[Compartment, List[Compartment]]] = None,\n        ncomp: Optional[int] = None,\n        nseg: Optional[int] = None,\n    ):\n        \"\"\"\n        Args:\n            compartments: A single compartment or a list of compartments that make up the\n                branch.\n            ncomp: Number of segments to divide the branch into. If `compartments` is an\n                a single compartment, than the compartment is repeated `ncomp` times to\n                create the branch.\n        \"\"\"\n        # Warnings and errors that deal with the change from `nseg` to `ncomp` change\n        # in Jaxley v0.5.0.\n        if ncomp is not None and nseg is not None:\n            raise ValueError(\"You passed `ncomp` and `nseg`. Please pass only `ncomp`.\")\n        if ncomp is None and nseg is not None:\n            ncomp = nseg\n\n        super().__init__()\n        assert (\n            isinstance(compartments, (Compartment, List)) or compartments is None\n        ), \"Only Compartment or List[Compartment] is allowed.\"\n        if isinstance(compartments, Compartment):\n            assert (\n                ncomp is not None\n            ), \"If `compartments` is not a list then you have to set `ncomp`.\"\n        compartments = Compartment() if compartments is None else compartments\n        ncomp = 1 if ncomp is None else ncomp\n\n        if isinstance(compartments, Compartment):\n            compartment_list = [compartments] * ncomp\n        else:\n            compartment_list = compartments\n\n        self.ncomp = len(compartment_list)\n        self.ncomp_per_branch = np.asarray([self.ncomp])\n        self.total_nbranches = 1\n        self.nbranches_per_cell = [1]\n        self._cumsum_nbranches = jnp.asarray([0, 1])\n        self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n\n        # Indexing.\n        self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)\n        self._append_params_and_states(self.branch_params, self.branch_states)\n        self.nodes[\"global_comp_index\"] = np.arange(self.ncomp).tolist()\n        self.nodes[\"global_branch_index\"] = [0] * self.ncomp\n        self.nodes[\"global_cell_index\"] = [0] * self.ncomp\n        self._update_local_indices()\n        self._init_view()\n\n        # Channels.\n        self._gather_channels_from_constituents(compartment_list)\n\n        self.branch_edges = pd.DataFrame(\n            dict(parent_branch_index=[], child_branch_index=[])\n        )\n\n        # For morphology indexing.\n        self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n            compute_children_and_parents(self.branch_edges)\n        )\n        self._internal_node_inds = jnp.arange(self.ncomp)\n\n        self._initialize()\n\n        # Coordinates.\n        self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n\n    def _init_morph_jaxley_spsolve(self):\n        self._solve_indexer = JaxleySolveIndexer(\n            cumsum_ncomp=self.cumsum_ncomp,\n            branchpoint_group_inds=np.asarray([]).astype(int),\n            remapped_node_indices=self._internal_node_inds,\n            children_in_level=[],\n            parents_in_level=[],\n            root_inds=np.asarray([0]),\n        )\n\n    def _init_morph_jax_spsolve(self):\n        \"\"\"Initialize morphology for the jax sparse voltage solver.\n\n        Explanation of `self._comp_eges['type']`:\n        `type == 0`: compartment <--> compartment (within branch)\n        `type == 1`: branchpoint --> parent-compartment\n        `type == 2`: branchpoint --> child-compartment\n        `type == 3`: parent-compartment --> branchpoint\n        `type == 4`: child-compartment --> branchpoint\n        \"\"\"\n        self._comp_edges = pd.DataFrame().from_dict(\n            {\n                \"source\": list(range(self.ncomp - 1)) + list(range(1, self.ncomp)),\n                \"sink\": list(range(1, self.ncomp)) + list(range(self.ncomp - 1)),\n            }\n        )\n        self._comp_edges[\"type\"] = 0\n        n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n        self._n_nodes = n_nodes\n        self._data_inds = data_inds\n        self._indices_jax_spsolve = indices\n        self._indptr_jax_spsolve = indptr\n\n    def __len__(self) -> int:\n        return self.ncomp\n
"},{"location":"reference/modules/#jaxley.modules.branch.Branch.__init__","title":"__init__(compartments=None, ncomp=None, nseg=None)","text":"

Parameters:

Name Type Description Default compartments Optional[Union[Compartment, List[Compartment]]]

A single compartment or a list of compartments that make up the branch.

None ncomp Optional[int]

Number of segments to divide the branch into. If compartments is an a single compartment, than the compartment is repeated ncomp times to create the branch.

None Source code in jaxley/modules/branch.py
@deprecated_kwargs(\"0.6.0\", [\"nseg\"])\ndef __init__(\n    self,\n    compartments: Optional[Union[Compartment, List[Compartment]]] = None,\n    ncomp: Optional[int] = None,\n    nseg: Optional[int] = None,\n):\n    \"\"\"\n    Args:\n        compartments: A single compartment or a list of compartments that make up the\n            branch.\n        ncomp: Number of segments to divide the branch into. If `compartments` is an\n            a single compartment, than the compartment is repeated `ncomp` times to\n            create the branch.\n    \"\"\"\n    # Warnings and errors that deal with the change from `nseg` to `ncomp` change\n    # in Jaxley v0.5.0.\n    if ncomp is not None and nseg is not None:\n        raise ValueError(\"You passed `ncomp` and `nseg`. Please pass only `ncomp`.\")\n    if ncomp is None and nseg is not None:\n        ncomp = nseg\n\n    super().__init__()\n    assert (\n        isinstance(compartments, (Compartment, List)) or compartments is None\n    ), \"Only Compartment or List[Compartment] is allowed.\"\n    if isinstance(compartments, Compartment):\n        assert (\n            ncomp is not None\n        ), \"If `compartments` is not a list then you have to set `ncomp`.\"\n    compartments = Compartment() if compartments is None else compartments\n    ncomp = 1 if ncomp is None else ncomp\n\n    if isinstance(compartments, Compartment):\n        compartment_list = [compartments] * ncomp\n    else:\n        compartment_list = compartments\n\n    self.ncomp = len(compartment_list)\n    self.ncomp_per_branch = np.asarray([self.ncomp])\n    self.total_nbranches = 1\n    self.nbranches_per_cell = [1]\n    self._cumsum_nbranches = jnp.asarray([0, 1])\n    self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n\n    # Indexing.\n    self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)\n    self._append_params_and_states(self.branch_params, self.branch_states)\n    self.nodes[\"global_comp_index\"] = np.arange(self.ncomp).tolist()\n    self.nodes[\"global_branch_index\"] = [0] * self.ncomp\n    self.nodes[\"global_cell_index\"] = [0] * self.ncomp\n    self._update_local_indices()\n    self._init_view()\n\n    # Channels.\n    self._gather_channels_from_constituents(compartment_list)\n\n    self.branch_edges = pd.DataFrame(\n        dict(parent_branch_index=[], child_branch_index=[])\n    )\n\n    # For morphology indexing.\n    self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n        compute_children_and_parents(self.branch_edges)\n    )\n    self._internal_node_inds = jnp.arange(self.ncomp)\n\n    self._initialize()\n\n    # Coordinates.\n    self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n
"},{"location":"reference/modules/#cell","title":"Cell","text":"

Bases: Module

Cell class.

This class defines a single cell that can be simulated by itself or connected with synapses to build a network. A cell is made up of several branches and supports intricate cell morphologies.

Source code in jaxley/modules/cell.py
class Cell(Module):\n    \"\"\"Cell class.\n\n    This class defines a single cell that can be simulated by itself or\n    connected with synapses to build a network. A cell is made up of several branches\n    and supports intricate cell morphologies.\n    \"\"\"\n\n    cell_params: Dict = {}\n    cell_states: Dict = {}\n\n    def __init__(\n        self,\n        branches: Optional[Union[Branch, List[Branch]]] = None,\n        parents: Optional[List[int]] = None,\n        xyzr: Optional[List[np.ndarray]] = None,\n    ):\n        \"\"\"Initialize a cell.\n\n        Args:\n            branches: A single branch or a list of branches that make up the cell.\n                If a single branch is provided, then the branch is repeated `len(parents)`\n                times to create the cell.\n            parents: The parent branch index for each branch. The first branch has no\n                parent and is therefore set to -1.\n            xyzr: For every branch, the x, y, and z coordinates and the radius at the\n                traced coordinates. Note that this is the full tracing (from SWC), not\n                the stick representation coordinates.\n        \"\"\"\n        super().__init__()\n        assert (\n            isinstance(branches, (Branch, List)) or branches is None\n        ), \"Only Branch or List[Branch] is allowed.\"\n        if branches is not None:\n            assert (\n                parents is not None\n            ), \"If `branches` is not a list then you have to set `parents`.\"\n        if isinstance(branches, List):\n            assert len(parents) == len(\n                branches\n            ), \"Ensure equally many parents, i.e. len(branches) == len(parents).\"\n\n        branches = Branch() if branches is None else branches\n        parents = [-1] if parents is None else parents\n\n        if isinstance(branches, Branch):\n            branch_list = [branches for _ in range(len(parents))]\n        else:\n            branch_list = branches\n\n        if xyzr is not None:\n            assert len(xyzr) == len(parents)\n            self.xyzr = xyzr\n        else:\n            # For every branch (`len(parents)`), we have a start and end point (`2`) and\n            # a (x,y,z,r) coordinate for each of them (`4`).\n            # Since `xyzr` is only inspected at `.vis()` and because it depends on the\n            # (potentially learned) length of every compartment, we only populate\n            # self.xyzr at `.vis()`.\n            self.xyzr = [float(\"NaN\") * np.zeros((2, 4)) for _ in range(len(parents))]\n\n        self.total_nbranches = len(branch_list)\n        self.nbranches_per_cell = [len(branch_list)]\n        self.comb_parents = jnp.asarray(parents)\n        self.comb_children = compute_children_indices(self.comb_parents)\n        self._cumsum_nbranches = np.asarray([0, len(branch_list)])\n\n        # Compartment structure. These arguments have to be rebuilt when `.set_ncomp()`\n        # is run.\n        self.ncomp_per_branch = np.asarray([branch.ncomp for branch in branch_list])\n        self.ncomp = int(np.max(self.ncomp_per_branch))\n        self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n        self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n\n        # Build nodes. Has to be changed when `.set_ncomp()` is run.\n        self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True)\n        self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n        self.nodes[\"global_branch_index\"] = np.repeat(\n            np.arange(self.total_nbranches), self.ncomp_per_branch\n        ).tolist()\n        self.nodes[\"global_cell_index\"] = np.repeat(0, self.cumsum_ncomp[-1]).tolist()\n        self._update_local_indices()\n        self._init_view()\n\n        # Appending general parameters (radius, length, r_a, cm) and channel parameters,\n        # as well as the states (v, and channel states).\n        self._append_params_and_states(self.cell_params, self.cell_states)\n\n        # Channels.\n        self._gather_channels_from_constituents(branch_list)\n\n        self.branch_edges = pd.DataFrame(\n            dict(\n                parent_branch_index=self.comb_parents[1:],\n                child_branch_index=np.arange(1, self.total_nbranches),\n            )\n        )\n\n        # For morphology indexing.\n        self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n            compute_children_and_parents(self.branch_edges)\n        )\n\n        self._initialize()\n\n    def _init_morph_jaxley_spsolve(self):\n        \"\"\"Initialize morphology for the custom sparse solver.\n\n        Running this function is only required for custom Jaxley solvers, i.e., for\n        `voltage_solver={'jaxley.stone', 'jaxley.thomas'}`. However, because at\n        `.__init__()` (when the function is run), we do not yet know which solver the\n        user will use. Therefore, we always run this function at `.__init__()`.\n        \"\"\"\n        children_and_parents = compute_morphology_indices_in_levels(\n            len(self._par_inds),\n            self._child_belongs_to_branchpoint,\n            self._par_inds,\n            self._child_inds,\n        )\n        branchpoint_group_inds = build_branchpoint_group_inds(\n            len(self._par_inds),\n            self._child_belongs_to_branchpoint,\n            self.cumsum_ncomp[-1],\n        )\n        parents = self.comb_parents\n        children_inds = children_and_parents[\"children\"]\n        parents_inds = children_and_parents[\"parents\"]\n\n        levels = compute_levels(parents)\n        children_in_level = compute_children_in_level(levels, children_inds)\n        parents_in_level = compute_parents_in_level(\n            levels, self._par_inds, parents_inds\n        )\n        levels_and_ncomp = pd.DataFrame().from_dict(\n            {\n                \"levels\": levels,\n                \"ncomps\": self.ncomp_per_branch,\n            }\n        )\n        levels_and_ncomp[\"max_ncomp_in_level\"] = levels_and_ncomp.groupby(\"levels\")[\n            \"ncomps\"\n        ].transform(\"max\")\n        padded_cumsum_ncomp = cumsum_leading_zero(\n            levels_and_ncomp[\"max_ncomp_in_level\"].to_numpy()\n        )\n\n        # Generate mapping to deal with the masking which allows using the custom\n        # sparse solver to deal with different ncomp per branch.\n        remapped_node_indices = remap_index_to_masked(\n            self._internal_node_inds,\n            self.nodes,\n            padded_cumsum_ncomp,\n            self.ncomp_per_branch,\n        )\n        self._solve_indexer = JaxleySolveIndexer(\n            cumsum_ncomp=padded_cumsum_ncomp,\n            branchpoint_group_inds=branchpoint_group_inds,\n            children_in_level=children_in_level,\n            parents_in_level=parents_in_level,\n            root_inds=np.asarray([0]),\n            remapped_node_indices=remapped_node_indices,\n        )\n\n    def _init_morph_jax_spsolve(self):\n        \"\"\"For morphology indexing with the `jax.sparse` voltage volver.\n\n        Explanation of `self._comp_eges['type']`:\n        `type == 0`: compartment <--> compartment (within branch)\n        `type == 1`: branchpoint --> parent-compartment\n        `type == 2`: branchpoint --> child-compartment\n        `type == 3`: parent-compartment --> branchpoint\n        `type == 4`: child-compartment --> branchpoint\n\n        Running this function is only required for generic sparse solvers, i.e., for\n        `voltage_solver='jax.sparse'`.\n        \"\"\"\n\n        # Edges between compartments within the branches.\n        self._comp_edges = pd.concat(\n            [\n                pd.DataFrame()\n                .from_dict(\n                    {\n                        \"source\": list(range(cumsum_ncomp, ncomp - 1 + cumsum_ncomp))\n                        + list(range(1 + cumsum_ncomp, ncomp + cumsum_ncomp)),\n                        \"sink\": list(range(1 + cumsum_ncomp, ncomp + cumsum_ncomp))\n                        + list(range(cumsum_ncomp, ncomp - 1 + cumsum_ncomp)),\n                    }\n                )\n                .astype(int)\n                for ncomp, cumsum_ncomp in zip(self.ncomp_per_branch, self.cumsum_ncomp)\n            ]\n        )\n        self._comp_edges[\"type\"] = 0\n\n        # Edges from branchpoints to compartments.\n        branchpoint_to_parent_edges = pd.DataFrame().from_dict(\n            {\n                \"source\": np.arange(len(self._par_inds)) + self.cumsum_ncomp[-1],\n                \"sink\": self.cumsum_ncomp[self._par_inds + 1] - 1,\n                \"type\": 1,\n            }\n        )\n        branchpoint_to_child_edges = pd.DataFrame().from_dict(\n            {\n                \"source\": self._child_belongs_to_branchpoint + self.cumsum_ncomp[-1],\n                \"sink\": self.cumsum_ncomp[self._child_inds],\n                \"type\": 2,\n            }\n        )\n        self._comp_edges = pd.concat(\n            [\n                self._comp_edges,\n                branchpoint_to_parent_edges,\n                branchpoint_to_child_edges,\n            ],\n            ignore_index=True,\n        )\n\n        # Edges from compartments to branchpoints.\n        parent_to_branchpoint_edges = branchpoint_to_parent_edges.rename(\n            columns={\"sink\": \"source\", \"source\": \"sink\"}\n        )\n        parent_to_branchpoint_edges[\"type\"] = 3\n        child_to_branchpoint_edges = branchpoint_to_child_edges.rename(\n            columns={\"sink\": \"source\", \"source\": \"sink\"}\n        )\n        child_to_branchpoint_edges[\"type\"] = 4\n\n        self._comp_edges = pd.concat(\n            [\n                self._comp_edges,\n                parent_to_branchpoint_edges,\n                child_to_branchpoint_edges,\n            ],\n            ignore_index=True,\n        )\n\n        n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n        self._n_nodes = n_nodes\n        self._data_inds = data_inds\n        self._indices_jax_spsolve = indices\n        self._indptr_jax_spsolve = indptr\n
"},{"location":"reference/modules/#jaxley.modules.cell.Cell.__init__","title":"__init__(branches=None, parents=None, xyzr=None)","text":"

Initialize a cell.

Parameters:

Name Type Description Default branches Optional[Union[Branch, List[Branch]]]

A single branch or a list of branches that make up the cell. If a single branch is provided, then the branch is repeated len(parents) times to create the cell.

None parents Optional[List[int]]

The parent branch index for each branch. The first branch has no parent and is therefore set to -1.

None xyzr Optional[List[ndarray]]

For every branch, the x, y, and z coordinates and the radius at the traced coordinates. Note that this is the full tracing (from SWC), not the stick representation coordinates.

None Source code in jaxley/modules/cell.py
def __init__(\n    self,\n    branches: Optional[Union[Branch, List[Branch]]] = None,\n    parents: Optional[List[int]] = None,\n    xyzr: Optional[List[np.ndarray]] = None,\n):\n    \"\"\"Initialize a cell.\n\n    Args:\n        branches: A single branch or a list of branches that make up the cell.\n            If a single branch is provided, then the branch is repeated `len(parents)`\n            times to create the cell.\n        parents: The parent branch index for each branch. The first branch has no\n            parent and is therefore set to -1.\n        xyzr: For every branch, the x, y, and z coordinates and the radius at the\n            traced coordinates. Note that this is the full tracing (from SWC), not\n            the stick representation coordinates.\n    \"\"\"\n    super().__init__()\n    assert (\n        isinstance(branches, (Branch, List)) or branches is None\n    ), \"Only Branch or List[Branch] is allowed.\"\n    if branches is not None:\n        assert (\n            parents is not None\n        ), \"If `branches` is not a list then you have to set `parents`.\"\n    if isinstance(branches, List):\n        assert len(parents) == len(\n            branches\n        ), \"Ensure equally many parents, i.e. len(branches) == len(parents).\"\n\n    branches = Branch() if branches is None else branches\n    parents = [-1] if parents is None else parents\n\n    if isinstance(branches, Branch):\n        branch_list = [branches for _ in range(len(parents))]\n    else:\n        branch_list = branches\n\n    if xyzr is not None:\n        assert len(xyzr) == len(parents)\n        self.xyzr = xyzr\n    else:\n        # For every branch (`len(parents)`), we have a start and end point (`2`) and\n        # a (x,y,z,r) coordinate for each of them (`4`).\n        # Since `xyzr` is only inspected at `.vis()` and because it depends on the\n        # (potentially learned) length of every compartment, we only populate\n        # self.xyzr at `.vis()`.\n        self.xyzr = [float(\"NaN\") * np.zeros((2, 4)) for _ in range(len(parents))]\n\n    self.total_nbranches = len(branch_list)\n    self.nbranches_per_cell = [len(branch_list)]\n    self.comb_parents = jnp.asarray(parents)\n    self.comb_children = compute_children_indices(self.comb_parents)\n    self._cumsum_nbranches = np.asarray([0, len(branch_list)])\n\n    # Compartment structure. These arguments have to be rebuilt when `.set_ncomp()`\n    # is run.\n    self.ncomp_per_branch = np.asarray([branch.ncomp for branch in branch_list])\n    self.ncomp = int(np.max(self.ncomp_per_branch))\n    self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n    self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n\n    # Build nodes. Has to be changed when `.set_ncomp()` is run.\n    self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True)\n    self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n    self.nodes[\"global_branch_index\"] = np.repeat(\n        np.arange(self.total_nbranches), self.ncomp_per_branch\n    ).tolist()\n    self.nodes[\"global_cell_index\"] = np.repeat(0, self.cumsum_ncomp[-1]).tolist()\n    self._update_local_indices()\n    self._init_view()\n\n    # Appending general parameters (radius, length, r_a, cm) and channel parameters,\n    # as well as the states (v, and channel states).\n    self._append_params_and_states(self.cell_params, self.cell_states)\n\n    # Channels.\n    self._gather_channels_from_constituents(branch_list)\n\n    self.branch_edges = pd.DataFrame(\n        dict(\n            parent_branch_index=self.comb_parents[1:],\n            child_branch_index=np.arange(1, self.total_nbranches),\n        )\n    )\n\n    # For morphology indexing.\n    self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n        compute_children_and_parents(self.branch_edges)\n    )\n\n    self._initialize()\n
"},{"location":"reference/modules/#network","title":"Network","text":"

Bases: Module

Network class.

This class defines a network of cells that can be connected with synapses.

Source code in jaxley/modules/network.py
class Network(Module):\n    \"\"\"Network class.\n\n    This class defines a network of cells that can be connected with synapses.\n    \"\"\"\n\n    network_params: Dict = {}\n    network_states: Dict = {}\n\n    def __init__(\n        self,\n        cells: List[Cell],\n    ):\n        \"\"\"Initialize network of cells and synapses.\n\n        Args:\n            cells: A list of cells that make up the network.\n        \"\"\"\n        super().__init__()\n        for cell in cells:\n            self.xyzr += deepcopy(cell.xyzr)\n\n        self._cells_list = cells\n        self.ncomp_per_branch = np.concatenate(\n            [cell.ncomp_per_branch for cell in cells]\n        )\n        self.ncomp = int(np.max(self.ncomp_per_branch))\n        self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n        self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n        self._append_params_and_states(self.network_params, self.network_states)\n\n        self.nbranches_per_cell = [cell.total_nbranches for cell in cells]\n        self.total_nbranches = sum(self.nbranches_per_cell)\n        self._cumsum_nbranches = cumsum_leading_zero(self.nbranches_per_cell)\n\n        self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True)\n        self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n        self.nodes[\"global_branch_index\"] = np.repeat(\n            np.arange(self.total_nbranches), self.ncomp_per_branch\n        ).tolist()\n        self.nodes[\"global_cell_index\"] = list(\n            itertools.chain(\n                *[[i] * int(cell.cumsum_ncomp[-1]) for i, cell in enumerate(cells)]\n            )\n        )\n        self._update_local_indices()\n        self._init_view()\n\n        parents = [cell.comb_parents for cell in cells]\n        self.comb_parents = jnp.concatenate(\n            [p.at[1:].add(self._cumsum_nbranches[i]) for i, p in enumerate(parents)]\n        )\n\n        # Two columns: `parent_branch_index` and `child_branch_index`. One row per\n        # branch, apart from those branches which do not have a parent (i.e.\n        # -1 in parents). For every branch, tracks the global index of that branch\n        # (`child_branch_index`) and the global index of its parent\n        # (`parent_branch_index`).\n        self.branch_edges = pd.DataFrame(\n            dict(\n                parent_branch_index=self.comb_parents[self.comb_parents != -1],\n                child_branch_index=np.where(self.comb_parents != -1)[0],\n            )\n        )\n\n        # For morphology indexing of both `jax.sparse` and the custom `jaxley` solvers.\n        self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n            compute_children_and_parents(self.branch_edges)\n        )\n\n        # `nbranchpoints` in each cell == cell._par_inds (because `par_inds` are unique).\n        nbranchpoints = jnp.asarray([len(cell._par_inds) for cell in cells])\n        self._cumsum_nbranchpoints_per_cell = cumsum_leading_zero(nbranchpoints)\n\n        # Channels.\n        self._gather_channels_from_constituents(cells)\n\n        self._initialize()\n        del self._cells_list\n\n    def __repr__(self):\n        return f\"{type(self).__name__} with {len(self.channels)} different channels and {len(self.synapses)} synapses. Use `.nodes` or `.edges` for details.\"\n\n    def _init_morph_jaxley_spsolve(self):\n        branchpoint_group_inds = build_branchpoint_group_inds(\n            len(self._par_inds),\n            self._child_belongs_to_branchpoint,\n            self.cumsum_ncomp[-1],\n        )\n        children_in_level = merge_cells(\n            self._cumsum_nbranches,\n            self._cumsum_nbranchpoints_per_cell,\n            [cell._solve_indexer.children_in_level for cell in self._cells_list],\n            exclude_first=False,\n        )\n        parents_in_level = merge_cells(\n            self._cumsum_nbranches,\n            self._cumsum_nbranchpoints_per_cell,\n            [cell._solve_indexer.parents_in_level for cell in self._cells_list],\n            exclude_first=False,\n        )\n        padded_cumsum_ncomp = cumsum_leading_zero(\n            np.concatenate(\n                [np.diff(cell._solve_indexer.cumsum_ncomp) for cell in self._cells_list]\n            )\n        )\n\n        # Generate mapping to dealing with the masking which allows using the custom\n        # sparse solver to deal with different ncomp per branch.\n        remapped_node_indices = remap_index_to_masked(\n            self._internal_node_inds,\n            self.nodes,\n            padded_cumsum_ncomp,\n            self.ncomp_per_branch,\n        )\n        self._solve_indexer = JaxleySolveIndexer(\n            cumsum_ncomp=padded_cumsum_ncomp,\n            branchpoint_group_inds=branchpoint_group_inds,\n            children_in_level=children_in_level,\n            parents_in_level=parents_in_level,\n            root_inds=self._cumsum_nbranches[:-1],\n            remapped_node_indices=remapped_node_indices,\n        )\n\n    def _init_morph_jax_spsolve(self):\n        \"\"\"Initialize the morphology for networks.\n\n        The reason that this function is a bit involved for a `Network` is that Jaxley\n        considers branchpoint nodes to be at the very end of __all__ nodes (i.e. the\n        branchpoints of the first cell are even after the compartments of the second\n        cell. The reason for this is that, otherwise, `cumsum_ncomp` becomes tricky).\n\n        To achieve this, we first loop over all compartments and append them, and then\n        loop over all branchpoints and append those. The code for building the indices\n        from the `comp_edges` is identical to `jx.Cell`.\n\n        Explanation of `self._comp_eges['type']`:\n        `type == 0`: compartment <--> compartment (within branch)\n        `type == 1`: branchpoint --> parent-compartment\n        `type == 2`: branchpoint --> child-compartment\n        `type == 3`: parent-compartment --> branchpoint\n        `type == 4`: child-compartment --> branchpoint\n        \"\"\"\n        self._cumsum_ncomp_per_cell = cumsum_leading_zero(\n            jnp.asarray([cell.cumsum_ncomp[-1] for cell in self.cells])\n        )\n        self._comp_edges = pd.DataFrame()\n\n        # Add all the internal nodes.\n        for offset, cell in zip(self._cumsum_ncomp_per_cell, self._cells_list):\n            condition = cell._comp_edges[\"type\"].to_numpy() == 0\n            rows = cell._comp_edges[condition]\n            self._comp_edges = pd.concat(\n                [self._comp_edges, [offset, offset, 0] + rows], ignore_index=True\n            )\n\n        # All branchpoint-to-compartment nodes.\n        start_branchpoints = self.cumsum_ncomp[-1]  # Index of the first branchpoint.\n        for offset, offset_branchpoints, cell in zip(\n            self._cumsum_ncomp_per_cell,\n            self._cumsum_nbranchpoints_per_cell,\n            self._cells_list,\n        ):\n            offset_within_cell = cell.cumsum_ncomp[-1]\n            condition = cell._comp_edges[\"type\"].isin([1, 2])\n            rows = cell._comp_edges[condition]\n            self._comp_edges = pd.concat(\n                [\n                    self._comp_edges,\n                    [\n                        start_branchpoints - offset_within_cell + offset_branchpoints,\n                        offset,\n                        0,\n                    ]\n                    + rows,\n                ],\n                ignore_index=True,\n            )\n\n        # All compartment-to-branchpoint nodes.\n        for offset, offset_branchpoints, cell in zip(\n            self._cumsum_ncomp_per_cell,\n            self._cumsum_nbranchpoints_per_cell,\n            self._cells_list,\n        ):\n            offset_within_cell = cell.cumsum_ncomp[-1]\n            condition = cell._comp_edges[\"type\"].isin([3, 4])\n            rows = cell._comp_edges[condition]\n            self._comp_edges = pd.concat(\n                [\n                    self._comp_edges,\n                    [\n                        offset,\n                        start_branchpoints - offset_within_cell + offset_branchpoints,\n                        0,\n                    ]\n                    + rows,\n                ],\n                ignore_index=True,\n            )\n\n        # Convert comp_edges to the index format required for `jax.sparse` solvers.\n        n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n        self._n_nodes = n_nodes\n        self._data_inds = data_inds\n        self._indices_jax_spsolve = indices\n        self._indptr_jax_spsolve = indptr\n\n    def _step_synapse(\n        self,\n        states: Dict,\n        syn_channels: List,\n        params: Dict,\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]:\n        \"\"\"Perform one step of the synapses and obtain their currents.\"\"\"\n        states = self._step_synapse_state(states, syn_channels, params, delta_t, edges)\n        states, current_terms = self._synapse_currents(\n            states, syn_channels, params, delta_t, edges\n        )\n        return states, current_terms\n\n    def _step_synapse_state(\n        self,\n        states: Dict,\n        syn_channels: List,\n        params: Dict,\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Dict:\n        voltages = states[\"v\"]\n\n        grouped_syns = edges.groupby(\"type\", sort=False, group_keys=False)\n        pre_syn_inds = grouped_syns[\"pre_global_comp_index\"].apply(list)\n        post_syn_inds = grouped_syns[\"post_global_comp_index\"].apply(list)\n        synapse_names = list(grouped_syns.indices.keys())\n\n        for i, synapse_type in enumerate(syn_channels):\n            assert (\n                synapse_names[i] == synapse_type._name\n            ), \"Mixup in the ordering of synapses. Please create an issue on Github.\"\n            synapse_param_names = list(synapse_type.synapse_params.keys())\n            synapse_state_names = list(synapse_type.synapse_states.keys())\n\n            synapse_params = {}\n            for p in synapse_param_names:\n                synapse_params[p] = params[p]\n            synapse_states = {}\n            for s in synapse_state_names:\n                synapse_states[s] = states[s]\n\n            pre_inds = np.asarray(pre_syn_inds[synapse_names[i]])\n            post_inds = np.asarray(post_syn_inds[synapse_names[i]])\n\n            # State updates.\n            states_updated = synapse_type.update_states(\n                synapse_states,\n                delta_t,\n                voltages[pre_inds],\n                voltages[post_inds],\n                synapse_params,\n            )\n\n            # Rebuild state.\n            for key, val in states_updated.items():\n                states[key] = val\n\n        return states\n\n    def _synapse_currents(\n        self,\n        states: Dict,\n        syn_channels: List,\n        params: Dict,\n        delta_t: float,\n        edges: pd.DataFrame,\n    ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]:\n        voltages = states[\"v\"]\n\n        grouped_syns = edges.groupby(\"type\", sort=False, group_keys=False)\n        pre_syn_inds = grouped_syns[\"pre_global_comp_index\"].apply(list)\n        post_syn_inds = grouped_syns[\"post_global_comp_index\"].apply(list)\n        synapse_names = list(grouped_syns.indices.keys())\n\n        syn_voltage_terms = jnp.zeros_like(voltages)\n        syn_constant_terms = jnp.zeros_like(voltages)\n        # Run with two different voltages that are `diff` apart to infer the slope and\n        # offset.\n        diff = 1e-3\n        for i, synapse_type in enumerate(syn_channels):\n            assert (\n                synapse_names[i] == synapse_type._name\n            ), \"Mixup in the ordering of synapses. Please create an issue on Github.\"\n            synapse_param_names = list(synapse_type.synapse_params.keys())\n            synapse_state_names = list(synapse_type.synapse_states.keys())\n\n            synapse_params = {}\n            for p in synapse_param_names:\n                synapse_params[p] = params[p]\n            synapse_states = {}\n            for s in synapse_state_names:\n                synapse_states[s] = states[s]\n\n            # Get pre and post indexes of the current synapse type.\n            pre_inds = np.asarray(pre_syn_inds[synapse_names[i]])\n            post_inds = np.asarray(post_syn_inds[synapse_names[i]])\n\n            # Compute slope and offset of the current through every synapse.\n            pre_v_and_perturbed = jnp.stack(\n                [voltages[pre_inds], voltages[pre_inds] + diff]\n            )\n            post_v_and_perturbed = jnp.stack(\n                [voltages[post_inds], voltages[post_inds] + diff]\n            )\n            synapse_currents = vmap(\n                synapse_type.compute_current, in_axes=(None, 0, 0, None)\n            )(\n                synapse_states,\n                pre_v_and_perturbed,\n                post_v_and_perturbed,\n                synapse_params,\n            )\n            synapse_currents_dist = convert_point_process_to_distributed(\n                synapse_currents,\n                params[\"radius\"][post_inds],\n                params[\"length\"][post_inds],\n            )\n\n            # Split into voltage and constant terms.\n            voltage_term = (synapse_currents_dist[1] - synapse_currents_dist[0]) / diff\n            constant_term = (\n                synapse_currents_dist[0] - voltage_term * voltages[post_inds]\n            )\n\n            # Gather slope and offset for every postsynaptic compartment.\n            gathered_syn_currents = gather_synapes(\n                len(voltages),\n                post_inds,\n                voltage_term,\n                constant_term,\n            )\n            syn_voltage_terms += gathered_syn_currents[0]\n            syn_constant_terms -= gathered_syn_currents[1]\n\n            # Add the synaptic currents through every compartment as state.\n            # `post_syn_currents` is a `jnp.ndarray` of as many elements as there are\n            # compartments in the network.\n            # `[0]` because we only use the non-perturbed voltage.\n            states[f\"i_{synapse_type._name}\"] = synapse_currents[0]\n\n        return states, (syn_voltage_terms, syn_constant_terms)\n\n    def arrange_in_layers(\n        self,\n        layers: List[int],\n        within_layer_offset: float = 500.0,\n        between_layer_offset: float = 1500.0,\n        vertical_layers: bool = False,\n    ):\n        \"\"\"Arrange the cells in the network to form layers.\n\n        Moves the cells in the network to arrange them into layers.\n\n        Args:\n            layers: List of integers specifying the number of cells in each layer.\n            within_layer_offset: Offset between cells within the same layer.\n            between_layer_offset: Offset between layers.\n            vertical_layers: If True, layers are arranged vertically.\n        \"\"\"\n        assert (\n            np.sum(layers) == self.shape[0]\n        ), \"The number of cells in the layers must match the number of cells in the network.\"\n        cells_in_layers = [\n            list(range(sum(layers[:i]), sum(layers[: i + 1])))\n            for i in range(len(layers))\n        ]\n\n        for l, cell_inds in enumerate(cells_in_layers):\n            layer = self.cell(cell_inds)\n            for i, cell in enumerate(layer.cells):\n                if vertical_layers:\n                    x_offset = (i - (len(cell_inds) - 1) / 2) * within_layer_offset\n                    y_offset = (len(layers) - 1 - l) * between_layer_offset\n                else:\n                    x_offset = l * between_layer_offset\n                    y_offset = (i - (len(cell_inds) - 1) / 2) * within_layer_offset\n\n                cell.move_to(x=x_offset, y=y_offset, z=0)\n\n    def vis(\n        self,\n        detail: str = \"full\",\n        ax: Optional[Axes] = None,\n        color: str = \"k\",\n        synapse_color: str = \"b\",\n        dims: Tuple[int] = (0, 1),\n        cell_plot_kwargs: Dict = {},\n        synapse_plot_kwargs: Dict = {},\n        synapse_scatter_kwargs: Dict = {},\n        **kwargs,  # absorb add. kwargs, i.e. to enable net.cell(0).vis(type=\"line\")\n    ) -> Axes:\n        \"\"\"Visualize the module.\n\n        Args:\n            detail: Either of [point, full]. `point` visualizes every neuron in the\n                network as a dot.\n                `full` plots the full morphology of every neuron. It requires that\n                `compute_xyz()` has been run.\n            color: The color in which cells are plotted.\n            synapse_color: The color in which synapses are plotted.\n            dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n                two of them.\n            cell_plot_kwargs: Keyword arguments passed to the plotting function for\n                cell morphologies. Only takes effect for `detail='full'`.\n            synapse_plot_kwargs: Keyword arguments passed to the plotting function for\n                syanpses.\n            synapse_scatter_kwargs: Keyword arguments passed to the scatter function for\n                syanpse terminals.\n        \"\"\"\n        xyz0 = self.cell(0).xyzr[0][:, :3]\n        same_xyz = np.all([np.all(xyz0 == cell.xyzr[0][:, :3]) for cell in self.cells])\n        if same_xyz:\n            warn(\n                \"Same coordinates for all cells. Consider using `move`, `move_to` or `arrange_in_layers` to move them.\"\n            )\n\n        if ax is None:\n            fig = plt.figure(figsize=(3, 3))\n            ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n        # detail=\"point\" -> pos taken to be the mean of all traced points on the cell.\n        cell_to_point_xyz = lambda cell: np.mean(np.vstack(cell.xyzr)[:, :3], axis=0)\n\n        dims_np = np.asarray(dims)\n        if detail == \"point\":\n            for cell in self.cells:\n                pos = cell_to_point_xyz(cell)[dims_np]\n                ax.scatter(*pos, color=color, **cell_plot_kwargs)\n        elif detail == \"full\":\n            ax = super().vis(dims=dims, color=color, ax=ax, **cell_plot_kwargs)\n        else:\n            raise ValueError(\"detail must be in {full, point}.\")\n\n        nodes = self.nodes.set_index(\"global_comp_index\")\n        for i, edge in self.edges.iterrows():\n            prepost_locs = []\n            for prepost in [\"pre\", \"post\"]:\n                loc, comp = edge[[prepost + \"_locs\", prepost + \"_global_comp_index\"]]\n                branch = nodes.loc[comp, \"global_branch_index\"]\n                cell = nodes.loc[comp, \"global_cell_index\"]\n                branch_xyz = self.xyzr[branch][:, :3]\n\n                xyz_loc = branch_xyz\n                if detail == \"point\":\n                    xyz_loc = cell_to_point_xyz(self.cell(cell))\n                elif len(branch_xyz) == 2:\n                    # If only start and end point of a branch are traced, perform a\n                    # linear interpolation to get the synpase location.\n                    xyz_loc = branch_xyz[0] + (branch_xyz[1] - branch_xyz[0]) * loc\n                else:\n                    # If densely traced, use intermediate trace values for synapse loc.\n                    middle_ind = int((len(branch_xyz) - 1) * loc)\n                    xyz_loc = xyz_loc[middle_ind]\n\n                prepost_locs.append(xyz_loc)\n            prepost_locs = np.stack(prepost_locs).T\n            ax.plot(*prepost_locs[dims_np], color=synapse_color, **synapse_plot_kwargs)\n            ax.scatter(\n                *prepost_locs[dims_np, 1], color=synapse_color, **synapse_scatter_kwargs\n            )\n\n        return ax\n\n    def _infer_synapse_type_ind(self, synapse_name):\n        syn_names = self.base.synapse_names\n        is_new_type = False if synapse_name in syn_names else True\n        type_ind = len(syn_names) if is_new_type else syn_names.index(synapse_name)\n        return type_ind, is_new_type\n\n    def _update_synapse_state_names(self, synapse_type):\n        # (Potentially) update variables that track meta information about synapses.\n        self.base.synapse_names.append(synapse_type._name)\n        self.base.synapse_param_names += list(synapse_type.synapse_params.keys())\n        self.base.synapse_state_names += list(synapse_type.synapse_states.keys())\n        self.base.synapses.append(synapse_type)\n\n    def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type):\n        # Add synapse types to the module and infer their unique identifier.\n        synapse_name = synapse_type._name\n        synapse_current_name = f\"i_{synapse_name}\"\n        type_ind, is_new = self._infer_synapse_type_ind(synapse_name)\n        if is_new:  # synapse is not known\n            self._update_synapse_state_names(synapse_type)\n            self.base.synapse_current_names.append(synapse_current_name)\n\n        index = len(self.base.edges)\n        indices = [idx for idx in range(index, index + len(pre_nodes))]\n        global_edge_index = pd.DataFrame({\"global_edge_index\": indices})\n        post_loc = loc_of_index(\n            post_nodes[\"global_comp_index\"].to_numpy(),\n            post_nodes[\"global_branch_index\"].to_numpy(),\n            self.ncomp_per_branch,\n        )\n        pre_loc = loc_of_index(\n            pre_nodes[\"global_comp_index\"].to_numpy(),\n            pre_nodes[\"global_branch_index\"].to_numpy(),\n            self.ncomp_per_branch,\n        )\n\n        # Define new synapses. Each row is one synapse.\n        pre_nodes = pre_nodes[[\"global_comp_index\"]]\n        pre_nodes.columns = [\"pre_global_comp_index\"]\n        post_nodes = post_nodes[[\"global_comp_index\"]]\n        post_nodes.columns = [\"post_global_comp_index\"]\n        new_rows = pd.concat(\n            [\n                global_edge_index,\n                pre_nodes.reset_index(drop=True),\n                post_nodes.reset_index(drop=True),\n            ],\n            axis=1,\n        )\n        new_rows[\"type\"] = synapse_name\n        new_rows[\"type_ind\"] = type_ind\n        new_rows[\"pre_locs\"] = pre_loc\n        new_rows[\"post_locs\"] = post_loc\n        self.base.edges = concat_and_ignore_empty(\n            [self.base.edges, new_rows], ignore_index=True, axis=0\n        )\n        self._add_params_to_edges(synapse_type, indices)\n        self.base.edges[\"controlled_by_param\"] = 0\n        self._edges_in_view = self.edges.index.to_numpy()\n\n    def _add_params_to_edges(self, synapse_type, indices):\n        # Add parameters and states to the `.edges` table.\n        for key, param_val in synapse_type.synapse_params.items():\n            self.base.edges.loc[indices, key] = param_val\n\n        # Update synaptic state array.\n        for key, state_val in synapse_type.synapse_states.items():\n            self.base.edges.loc[indices, key] = state_val\n
"},{"location":"reference/modules/#jaxley.modules.network.Network.__init__","title":"__init__(cells)","text":"

Initialize network of cells and synapses.

Parameters:

Name Type Description Default cells List[Cell]

A list of cells that make up the network.

required Source code in jaxley/modules/network.py
def __init__(\n    self,\n    cells: List[Cell],\n):\n    \"\"\"Initialize network of cells and synapses.\n\n    Args:\n        cells: A list of cells that make up the network.\n    \"\"\"\n    super().__init__()\n    for cell in cells:\n        self.xyzr += deepcopy(cell.xyzr)\n\n    self._cells_list = cells\n    self.ncomp_per_branch = np.concatenate(\n        [cell.ncomp_per_branch for cell in cells]\n    )\n    self.ncomp = int(np.max(self.ncomp_per_branch))\n    self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n    self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n    self._append_params_and_states(self.network_params, self.network_states)\n\n    self.nbranches_per_cell = [cell.total_nbranches for cell in cells]\n    self.total_nbranches = sum(self.nbranches_per_cell)\n    self._cumsum_nbranches = cumsum_leading_zero(self.nbranches_per_cell)\n\n    self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True)\n    self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n    self.nodes[\"global_branch_index\"] = np.repeat(\n        np.arange(self.total_nbranches), self.ncomp_per_branch\n    ).tolist()\n    self.nodes[\"global_cell_index\"] = list(\n        itertools.chain(\n            *[[i] * int(cell.cumsum_ncomp[-1]) for i, cell in enumerate(cells)]\n        )\n    )\n    self._update_local_indices()\n    self._init_view()\n\n    parents = [cell.comb_parents for cell in cells]\n    self.comb_parents = jnp.concatenate(\n        [p.at[1:].add(self._cumsum_nbranches[i]) for i, p in enumerate(parents)]\n    )\n\n    # Two columns: `parent_branch_index` and `child_branch_index`. One row per\n    # branch, apart from those branches which do not have a parent (i.e.\n    # -1 in parents). For every branch, tracks the global index of that branch\n    # (`child_branch_index`) and the global index of its parent\n    # (`parent_branch_index`).\n    self.branch_edges = pd.DataFrame(\n        dict(\n            parent_branch_index=self.comb_parents[self.comb_parents != -1],\n            child_branch_index=np.where(self.comb_parents != -1)[0],\n        )\n    )\n\n    # For morphology indexing of both `jax.sparse` and the custom `jaxley` solvers.\n    self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n        compute_children_and_parents(self.branch_edges)\n    )\n\n    # `nbranchpoints` in each cell == cell._par_inds (because `par_inds` are unique).\n    nbranchpoints = jnp.asarray([len(cell._par_inds) for cell in cells])\n    self._cumsum_nbranchpoints_per_cell = cumsum_leading_zero(nbranchpoints)\n\n    # Channels.\n    self._gather_channels_from_constituents(cells)\n\n    self._initialize()\n    del self._cells_list\n
"},{"location":"reference/modules/#jaxley.modules.network.Network.arrange_in_layers","title":"arrange_in_layers(layers, within_layer_offset=500.0, between_layer_offset=1500.0, vertical_layers=False)","text":"

Arrange the cells in the network to form layers.

Moves the cells in the network to arrange them into layers.

Parameters:

Name Type Description Default layers List[int]

List of integers specifying the number of cells in each layer.

required within_layer_offset float

Offset between cells within the same layer.

500.0 between_layer_offset float

Offset between layers.

1500.0 vertical_layers bool

If True, layers are arranged vertically.

False Source code in jaxley/modules/network.py
def arrange_in_layers(\n    self,\n    layers: List[int],\n    within_layer_offset: float = 500.0,\n    between_layer_offset: float = 1500.0,\n    vertical_layers: bool = False,\n):\n    \"\"\"Arrange the cells in the network to form layers.\n\n    Moves the cells in the network to arrange them into layers.\n\n    Args:\n        layers: List of integers specifying the number of cells in each layer.\n        within_layer_offset: Offset between cells within the same layer.\n        between_layer_offset: Offset between layers.\n        vertical_layers: If True, layers are arranged vertically.\n    \"\"\"\n    assert (\n        np.sum(layers) == self.shape[0]\n    ), \"The number of cells in the layers must match the number of cells in the network.\"\n    cells_in_layers = [\n        list(range(sum(layers[:i]), sum(layers[: i + 1])))\n        for i in range(len(layers))\n    ]\n\n    for l, cell_inds in enumerate(cells_in_layers):\n        layer = self.cell(cell_inds)\n        for i, cell in enumerate(layer.cells):\n            if vertical_layers:\n                x_offset = (i - (len(cell_inds) - 1) / 2) * within_layer_offset\n                y_offset = (len(layers) - 1 - l) * between_layer_offset\n            else:\n                x_offset = l * between_layer_offset\n                y_offset = (i - (len(cell_inds) - 1) / 2) * within_layer_offset\n\n            cell.move_to(x=x_offset, y=y_offset, z=0)\n
"},{"location":"reference/modules/#jaxley.modules.network.Network.vis","title":"vis(detail='full', ax=None, color='k', synapse_color='b', dims=(0, 1), cell_plot_kwargs={}, synapse_plot_kwargs={}, synapse_scatter_kwargs={}, **kwargs)","text":"

Visualize the module.

Parameters:

Name Type Description Default detail str

Either of [point, full]. point visualizes every neuron in the network as a dot. full plots the full morphology of every neuron. It requires that compute_xyz() has been run.

'full' color str

The color in which cells are plotted.

'k' synapse_color str

The color in which synapses are plotted.

'b' dims Tuple[int]

Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.

(0, 1) cell_plot_kwargs Dict

Keyword arguments passed to the plotting function for cell morphologies. Only takes effect for detail='full'.

{} synapse_plot_kwargs Dict

Keyword arguments passed to the plotting function for syanpses.

{} synapse_scatter_kwargs Dict

Keyword arguments passed to the scatter function for syanpse terminals.

{} Source code in jaxley/modules/network.py
def vis(\n    self,\n    detail: str = \"full\",\n    ax: Optional[Axes] = None,\n    color: str = \"k\",\n    synapse_color: str = \"b\",\n    dims: Tuple[int] = (0, 1),\n    cell_plot_kwargs: Dict = {},\n    synapse_plot_kwargs: Dict = {},\n    synapse_scatter_kwargs: Dict = {},\n    **kwargs,  # absorb add. kwargs, i.e. to enable net.cell(0).vis(type=\"line\")\n) -> Axes:\n    \"\"\"Visualize the module.\n\n    Args:\n        detail: Either of [point, full]. `point` visualizes every neuron in the\n            network as a dot.\n            `full` plots the full morphology of every neuron. It requires that\n            `compute_xyz()` has been run.\n        color: The color in which cells are plotted.\n        synapse_color: The color in which synapses are plotted.\n        dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n            two of them.\n        cell_plot_kwargs: Keyword arguments passed to the plotting function for\n            cell morphologies. Only takes effect for `detail='full'`.\n        synapse_plot_kwargs: Keyword arguments passed to the plotting function for\n            syanpses.\n        synapse_scatter_kwargs: Keyword arguments passed to the scatter function for\n            syanpse terminals.\n    \"\"\"\n    xyz0 = self.cell(0).xyzr[0][:, :3]\n    same_xyz = np.all([np.all(xyz0 == cell.xyzr[0][:, :3]) for cell in self.cells])\n    if same_xyz:\n        warn(\n            \"Same coordinates for all cells. Consider using `move`, `move_to` or `arrange_in_layers` to move them.\"\n        )\n\n    if ax is None:\n        fig = plt.figure(figsize=(3, 3))\n        ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n    # detail=\"point\" -> pos taken to be the mean of all traced points on the cell.\n    cell_to_point_xyz = lambda cell: np.mean(np.vstack(cell.xyzr)[:, :3], axis=0)\n\n    dims_np = np.asarray(dims)\n    if detail == \"point\":\n        for cell in self.cells:\n            pos = cell_to_point_xyz(cell)[dims_np]\n            ax.scatter(*pos, color=color, **cell_plot_kwargs)\n    elif detail == \"full\":\n        ax = super().vis(dims=dims, color=color, ax=ax, **cell_plot_kwargs)\n    else:\n        raise ValueError(\"detail must be in {full, point}.\")\n\n    nodes = self.nodes.set_index(\"global_comp_index\")\n    for i, edge in self.edges.iterrows():\n        prepost_locs = []\n        for prepost in [\"pre\", \"post\"]:\n            loc, comp = edge[[prepost + \"_locs\", prepost + \"_global_comp_index\"]]\n            branch = nodes.loc[comp, \"global_branch_index\"]\n            cell = nodes.loc[comp, \"global_cell_index\"]\n            branch_xyz = self.xyzr[branch][:, :3]\n\n            xyz_loc = branch_xyz\n            if detail == \"point\":\n                xyz_loc = cell_to_point_xyz(self.cell(cell))\n            elif len(branch_xyz) == 2:\n                # If only start and end point of a branch are traced, perform a\n                # linear interpolation to get the synpase location.\n                xyz_loc = branch_xyz[0] + (branch_xyz[1] - branch_xyz[0]) * loc\n            else:\n                # If densely traced, use intermediate trace values for synapse loc.\n                middle_ind = int((len(branch_xyz) - 1) * loc)\n                xyz_loc = xyz_loc[middle_ind]\n\n            prepost_locs.append(xyz_loc)\n        prepost_locs = np.stack(prepost_locs).T\n        ax.plot(*prepost_locs[dims_np], color=synapse_color, **synapse_plot_kwargs)\n        ax.scatter(\n            *prepost_locs[dims_np, 1], color=synapse_color, **synapse_scatter_kwargs\n        )\n\n    return ax\n
"},{"location":"reference/optimize/","title":"Optimization","text":""},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer","title":"TypeOptimizer","text":"

optax wrapper which allows different argument values for different params.

Source code in jaxley/optimize/optimizer.py
class TypeOptimizer:\n    \"\"\"`optax` wrapper which allows different argument values for different params.\"\"\"\n\n    def __init__(\n        self,\n        optimizer: Callable,\n        optimizer_args: Dict[str, Any],\n        opt_params: List[Dict[str, jnp.ndarray]],\n    ):\n        \"\"\"Create the optimizers.\n\n        This requires access to `opt_params` in order to know how many optimizers\n        should be created. It creates `len(opt_params)` optimizers.\n\n        Example usage:\n        ```\n        lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\n        optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\n        opt_state = optimizer.init(opt_params)\n        ```\n\n        ```\n        optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\n        optimizer = TypeOptimizer(\n            lambda args: optax.sgd(args[0], momentum=args[1]),\n            optimizer_args,\n            opt_params\n        )\n        opt_state = optimizer.init(opt_params)\n        ```\n\n        Args:\n            optimizer: A Callable that takes the learning rate and returns the\n                `optax.optimizer` which should be used.\n            optimizer_args: The arguments for different kinds of parameters.\n                Each item of the dictionary will be passed to the `Callable` passed to\n                `optimizer`.\n            opt_params: The parameters to be optimized. The exact values are not used,\n                only the number of elements in the list and the key of each dict.\n        \"\"\"\n        self.base_optimizer = optimizer\n\n        self.optimizers = []\n        for params in opt_params:\n            names = list(params.keys())\n            assert len(names) == 1, \"Multiple parameters were added at once.\"\n            name = names[0]\n            optimizer = self.base_optimizer(optimizer_args[name])\n            self.optimizers.append({name: optimizer})\n\n    def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:\n        \"\"\"Initialize the optimizers. Equivalent to `optax.optimizers.init()`.\"\"\"\n        opt_states = []\n        for params, optimizer in zip(opt_params, self.optimizers):\n            name = list(optimizer.keys())[0]\n            opt_state = optimizer[name].init(params)\n            opt_states.append(opt_state)\n        return opt_states\n\n    def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:\n        \"\"\"Update the optimizers. Equivalent to `optax.optimizers.update()`.\"\"\"\n        all_updates = []\n        new_opt_states = []\n        for grad, state, opt in zip(gradient, opt_state, self.optimizers):\n            name = list(opt.keys())[0]\n            updates, new_opt_state = opt[name].update(grad, state)\n            all_updates.append(updates)\n            new_opt_states.append(new_opt_state)\n        return all_updates, new_opt_states\n
"},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.__init__","title":"__init__(optimizer, optimizer_args, opt_params)","text":"

Create the optimizers.

This requires access to opt_params in order to know how many optimizers should be created. It creates len(opt_params) optimizers.

Example usage:

lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\noptimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\nopt_state = optimizer.init(opt_params)\n

optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\noptimizer = TypeOptimizer(\n    lambda args: optax.sgd(args[0], momentum=args[1]),\n    optimizer_args,\n    opt_params\n)\nopt_state = optimizer.init(opt_params)\n

Parameters:

Name Type Description Default optimizer Callable

A Callable that takes the learning rate and returns the optax.optimizer which should be used.

required optimizer_args Dict[str, Any]

The arguments for different kinds of parameters. Each item of the dictionary will be passed to the Callable passed to optimizer.

required opt_params List[Dict[str, ndarray]]

The parameters to be optimized. The exact values are not used, only the number of elements in the list and the key of each dict.

required Source code in jaxley/optimize/optimizer.py
def __init__(\n    self,\n    optimizer: Callable,\n    optimizer_args: Dict[str, Any],\n    opt_params: List[Dict[str, jnp.ndarray]],\n):\n    \"\"\"Create the optimizers.\n\n    This requires access to `opt_params` in order to know how many optimizers\n    should be created. It creates `len(opt_params)` optimizers.\n\n    Example usage:\n    ```\n    lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\n    optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\n    opt_state = optimizer.init(opt_params)\n    ```\n\n    ```\n    optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\n    optimizer = TypeOptimizer(\n        lambda args: optax.sgd(args[0], momentum=args[1]),\n        optimizer_args,\n        opt_params\n    )\n    opt_state = optimizer.init(opt_params)\n    ```\n\n    Args:\n        optimizer: A Callable that takes the learning rate and returns the\n            `optax.optimizer` which should be used.\n        optimizer_args: The arguments for different kinds of parameters.\n            Each item of the dictionary will be passed to the `Callable` passed to\n            `optimizer`.\n        opt_params: The parameters to be optimized. The exact values are not used,\n            only the number of elements in the list and the key of each dict.\n    \"\"\"\n    self.base_optimizer = optimizer\n\n    self.optimizers = []\n    for params in opt_params:\n        names = list(params.keys())\n        assert len(names) == 1, \"Multiple parameters were added at once.\"\n        name = names[0]\n        optimizer = self.base_optimizer(optimizer_args[name])\n        self.optimizers.append({name: optimizer})\n
"},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.init","title":"init(opt_params)","text":"

Initialize the optimizers. Equivalent to optax.optimizers.init().

Source code in jaxley/optimize/optimizer.py
def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:\n    \"\"\"Initialize the optimizers. Equivalent to `optax.optimizers.init()`.\"\"\"\n    opt_states = []\n    for params, optimizer in zip(opt_params, self.optimizers):\n        name = list(optimizer.keys())[0]\n        opt_state = optimizer[name].init(params)\n        opt_states.append(opt_state)\n    return opt_states\n
"},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.update","title":"update(gradient, opt_state)","text":"

Update the optimizers. Equivalent to optax.optimizers.update().

Source code in jaxley/optimize/optimizer.py
def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:\n    \"\"\"Update the optimizers. Equivalent to `optax.optimizers.update()`.\"\"\"\n    all_updates = []\n    new_opt_states = []\n    for grad, state, opt in zip(gradient, opt_state, self.optimizers):\n        name = list(opt.keys())[0]\n        updates, new_opt_state = opt[name].update(grad, state)\n        all_updates.append(updates)\n        new_opt_states.append(new_opt_state)\n    return all_updates, new_opt_states\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.AffineTransform","title":"AffineTransform","text":"

Bases: Transform

Source code in jaxley/optimize/transforms.py
class AffineTransform(Transform):\n    def __init__(self, scale: ArrayLike, shift: ArrayLike):\n        \"\"\"This transform rescales and shifts the input.\n\n        Args:\n            scale (ArrayLike): Scaling factor.\n            shift (ArrayLike): Additive shift.\n\n        Raises:\n            ValueError: Scale needs to be larger than 0\n        \"\"\"\n        if jnp.allclose(scale, 0):\n            raise ValueError(\"a cannot be zero, must be invertible\")\n        self.a = scale\n        self.b = shift\n\n    def forward(self, x: ArrayLike) -> Array:\n        return self.a * x + self.b\n\n    def inverse(self, x: ArrayLike) -> Array:\n        return (x - self.b) / self.a\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.AffineTransform.__init__","title":"__init__(scale, shift)","text":"

This transform rescales and shifts the input.

Parameters:

Name Type Description Default scale ArrayLike

Scaling factor.

required shift ArrayLike

Additive shift.

required

Raises:

Type Description ValueError

Scale needs to be larger than 0

Source code in jaxley/optimize/transforms.py
def __init__(self, scale: ArrayLike, shift: ArrayLike):\n    \"\"\"This transform rescales and shifts the input.\n\n    Args:\n        scale (ArrayLike): Scaling factor.\n        shift (ArrayLike): Additive shift.\n\n    Raises:\n        ValueError: Scale needs to be larger than 0\n    \"\"\"\n    if jnp.allclose(scale, 0):\n        raise ValueError(\"a cannot be zero, must be invertible\")\n    self.a = scale\n    self.b = shift\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ChainTransform","title":"ChainTransform","text":"

Bases: Transform

Chaining together multiple transformations

Source code in jaxley/optimize/transforms.py
class ChainTransform(Transform):\n    \"\"\"Chaining together multiple transformations\"\"\"\n\n    def __init__(self, transforms: Sequence[Transform]) -> None:\n        \"\"\"A chain of transformations\n\n        Args:\n            transforms (Sequence[Transform]): Transforms to apply\n        \"\"\"\n        super().__init__()\n        self.transforms = transforms\n\n    def forward(self, x: ArrayLike) -> Array:\n        for transform in self.transforms:\n            x = transform(x)\n        return x\n\n    def inverse(self, y: ArrayLike) -> Array:\n        for transform in reversed(self.transforms):\n            y = transform.inverse(y)\n        return y\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ChainTransform.__init__","title":"__init__(transforms)","text":"

A chain of transformations

Parameters:

Name Type Description Default transforms Sequence[Transform]

Transforms to apply

required Source code in jaxley/optimize/transforms.py
def __init__(self, transforms: Sequence[Transform]) -> None:\n    \"\"\"A chain of transformations\n\n    Args:\n        transforms (Sequence[Transform]): Transforms to apply\n    \"\"\"\n    super().__init__()\n    self.transforms = transforms\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.CustomTransform","title":"CustomTransform","text":"

Bases: Transform

Custom transformation

Source code in jaxley/optimize/transforms.py
class CustomTransform(Transform):\n    \"\"\"Custom transformation\"\"\"\n\n    def __init__(self, forward_fn: Callable, inverse_fn: Callable) -> None:\n        \"\"\"A custom transformation using a user-defined froward and\n        inverse function\n\n        Args:\n            forward_fn (Callable): Forward transformation\n            inverse_fn (Callable): Inverse transformation\n        \"\"\"\n        super().__init__()\n        self.forward_fn = forward_fn\n        self.inverse_fn = inverse_fn\n\n    def forward(self, x: ArrayLike) -> Array:\n        return self.forward_fn(x)\n\n    def inverse(self, y: ArrayLike) -> Array:\n        return self.inverse_fn(y)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.CustomTransform.__init__","title":"__init__(forward_fn, inverse_fn)","text":"

A custom transformation using a user-defined froward and inverse function

Parameters:

Name Type Description Default forward_fn Callable

Forward transformation

required inverse_fn Callable

Inverse transformation

required Source code in jaxley/optimize/transforms.py
def __init__(self, forward_fn: Callable, inverse_fn: Callable) -> None:\n    \"\"\"A custom transformation using a user-defined froward and\n    inverse function\n\n    Args:\n        forward_fn (Callable): Forward transformation\n        inverse_fn (Callable): Inverse transformation\n    \"\"\"\n    super().__init__()\n    self.forward_fn = forward_fn\n    self.inverse_fn = inverse_fn\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.MaskedTransform","title":"MaskedTransform","text":"

Bases: Transform

Source code in jaxley/optimize/transforms.py
class MaskedTransform(Transform):\n    def __init__(self, mask: ArrayLike, transform: Transform) -> None:\n        \"\"\"A masked transformation\n\n        Args:\n            mask (ArrayLike): Which elements to transform\n            transform (Transform): Transformation to apply\n        \"\"\"\n        super().__init__()\n        self.mask = mask\n        self.transform = transform\n\n    def forward(self, x: ArrayLike) -> Array:\n        return jnp.where(self.mask, self.transform.forward(x), x)\n\n    def inverse(self, y: ArrayLike) -> Array:\n        return jnp.where(self.mask, self.transform.inverse(y), y)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.MaskedTransform.__init__","title":"__init__(mask, transform)","text":"

A masked transformation

Parameters:

Name Type Description Default mask ArrayLike

Which elements to transform

required transform Transform

Transformation to apply

required Source code in jaxley/optimize/transforms.py
def __init__(self, mask: ArrayLike, transform: Transform) -> None:\n    \"\"\"A masked transformation\n\n    Args:\n        mask (ArrayLike): Which elements to transform\n        transform (Transform): Transformation to apply\n    \"\"\"\n    super().__init__()\n    self.mask = mask\n    self.transform = transform\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.NegSoftplusTransform","title":"NegSoftplusTransform","text":"

Bases: SoftplusTransform

Negative softplus transformation.

Source code in jaxley/optimize/transforms.py
class NegSoftplusTransform(SoftplusTransform):\n    \"\"\"Negative softplus transformation.\"\"\"\n\n    def __init__(self, upper: ArrayLike) -> None:\n        \"\"\"This transform maps any value bijectively to the interval (-inf, upper].\n\n        Args:\n            upper (ArrayLike): Upper bound of the interval.\n        \"\"\"\n        super().__init__(upper)\n\n    def forward(self, x: ArrayLike) -> Array:\n        return -super().forward(-x)\n\n    def inverse(self, y: ArrayLike) -> Array:\n        return -super().inverse(-y)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.NegSoftplusTransform.__init__","title":"__init__(upper)","text":"

This transform maps any value bijectively to the interval (-inf, upper].

Parameters:

Name Type Description Default upper ArrayLike

Upper bound of the interval.

required Source code in jaxley/optimize/transforms.py
def __init__(self, upper: ArrayLike) -> None:\n    \"\"\"This transform maps any value bijectively to the interval (-inf, upper].\n\n    Args:\n        upper (ArrayLike): Upper bound of the interval.\n    \"\"\"\n    super().__init__(upper)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform","title":"ParamTransform","text":"

Parameter transformation utility.

This class is used to transform parameters usually from an unconstrained space to a constrained space and back (bacause most biophysical parameter are bounded). The user can specify a PyTree of transforms that are applied to the parameters.

Attributes:

Name Type Description tf_dict

A PyTree of transforms for each parameter.

Source code in jaxley/optimize/transforms.py
class ParamTransform:\n    \"\"\"Parameter transformation utility.\n\n    This class is used to transform parameters usually from an unconstrained space to a constrained space\n    and back (bacause most biophysical parameter are bounded). The user can specify a PyTree of transforms\n    that are applied to the parameters.\n\n    Attributes:\n        tf_dict: A PyTree of transforms for each parameter.\n\n    \"\"\"\n\n    def __init__(self, tf_dict: List[Dict[str, Transform]] | Transform) -> None:\n        \"\"\"Creates a new ParamTransform object.\n\n        Args:\n            tf_dict: A PyTree of transforms for each parameter.\n        \"\"\"\n\n        self.tf_dict = tf_dict\n\n    def forward(\n        self, params: List[Dict[str, ArrayLike]] | ArrayLike\n    ) -> Dict[str, Array]:\n        \"\"\"Pushes unconstrained parameters through a tf such that they fit the interval.\n\n        Args:\n            params: A list of dictionaries (or any PyTree) with unconstrained parameters.\n\n        Returns:\n            A list of dictionaries (or any PyTree) with transformed parameters.\n\n        \"\"\"\n\n        return jax.tree_util.tree_map(lambda x, tf: tf.forward(x), params, self.tf_dict)\n\n    def inverse(\n        self, params: List[Dict[str, ArrayLike]] | ArrayLike\n    ) -> Dict[str, Array]:\n        \"\"\"Takes parameters from within the interval and makes them unconstrained.\n\n        Args:\n            params: A list of dictionaries (or any PyTree) with transformed parameters.\n\n        Returns:\n            A list of dictionaries (or any PyTree) with unconstrained parameters.\n        \"\"\"\n\n        return jax.tree_util.tree_map(lambda x, tf: tf.inverse(x), params, self.tf_dict)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.__init__","title":"__init__(tf_dict)","text":"

Creates a new ParamTransform object.

Parameters:

Name Type Description Default tf_dict List[Dict[str, Transform]] | Transform

A PyTree of transforms for each parameter.

required Source code in jaxley/optimize/transforms.py
def __init__(self, tf_dict: List[Dict[str, Transform]] | Transform) -> None:\n    \"\"\"Creates a new ParamTransform object.\n\n    Args:\n        tf_dict: A PyTree of transforms for each parameter.\n    \"\"\"\n\n    self.tf_dict = tf_dict\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.forward","title":"forward(params)","text":"

Pushes unconstrained parameters through a tf such that they fit the interval.

Parameters:

Name Type Description Default params List[Dict[str, ArrayLike]] | ArrayLike

A list of dictionaries (or any PyTree) with unconstrained parameters.

required

Returns:

Type Description Dict[str, Array]

A list of dictionaries (or any PyTree) with transformed parameters.

Source code in jaxley/optimize/transforms.py
def forward(\n    self, params: List[Dict[str, ArrayLike]] | ArrayLike\n) -> Dict[str, Array]:\n    \"\"\"Pushes unconstrained parameters through a tf such that they fit the interval.\n\n    Args:\n        params: A list of dictionaries (or any PyTree) with unconstrained parameters.\n\n    Returns:\n        A list of dictionaries (or any PyTree) with transformed parameters.\n\n    \"\"\"\n\n    return jax.tree_util.tree_map(lambda x, tf: tf.forward(x), params, self.tf_dict)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.inverse","title":"inverse(params)","text":"

Takes parameters from within the interval and makes them unconstrained.

Parameters:

Name Type Description Default params List[Dict[str, ArrayLike]] | ArrayLike

A list of dictionaries (or any PyTree) with transformed parameters.

required

Returns:

Type Description Dict[str, Array]

A list of dictionaries (or any PyTree) with unconstrained parameters.

Source code in jaxley/optimize/transforms.py
def inverse(\n    self, params: List[Dict[str, ArrayLike]] | ArrayLike\n) -> Dict[str, Array]:\n    \"\"\"Takes parameters from within the interval and makes them unconstrained.\n\n    Args:\n        params: A list of dictionaries (or any PyTree) with transformed parameters.\n\n    Returns:\n        A list of dictionaries (or any PyTree) with unconstrained parameters.\n    \"\"\"\n\n    return jax.tree_util.tree_map(lambda x, tf: tf.inverse(x), params, self.tf_dict)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SigmoidTransform","title":"SigmoidTransform","text":"

Bases: Transform

Sigmoid transformation.

Source code in jaxley/optimize/transforms.py
class SigmoidTransform(Transform):\n    \"\"\"Sigmoid transformation.\"\"\"\n\n    def __init__(self, lower: ArrayLike, upper: ArrayLike) -> None:\n        \"\"\"This transform maps any value bijectively to the interval [lower, upper].\n\n        Args:\n            lower (ArrayLike): Lower bound of the interval.\n            upper (ArrayLike): Upper bound of the interval.\n        \"\"\"\n        super().__init__()\n        self.lower = lower\n        self.width = upper - lower\n\n    def forward(self, x: ArrayLike) -> Array:\n        y = 1.0 / (1.0 + save_exp(-x))\n        return self.lower + self.width * y\n\n    def inverse(self, y: ArrayLike) -> Array:\n        x = (y - self.lower) / self.width\n        x = -jnp.log((1.0 / x) - 1.0)\n        return x\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SigmoidTransform.__init__","title":"__init__(lower, upper)","text":"

This transform maps any value bijectively to the interval [lower, upper].

Parameters:

Name Type Description Default lower ArrayLike

Lower bound of the interval.

required upper ArrayLike

Upper bound of the interval.

required Source code in jaxley/optimize/transforms.py
def __init__(self, lower: ArrayLike, upper: ArrayLike) -> None:\n    \"\"\"This transform maps any value bijectively to the interval [lower, upper].\n\n    Args:\n        lower (ArrayLike): Lower bound of the interval.\n        upper (ArrayLike): Upper bound of the interval.\n    \"\"\"\n    super().__init__()\n    self.lower = lower\n    self.width = upper - lower\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SoftplusTransform","title":"SoftplusTransform","text":"

Bases: Transform

Softplus transformation.

Source code in jaxley/optimize/transforms.py
class SoftplusTransform(Transform):\n    \"\"\"Softplus transformation.\"\"\"\n\n    def __init__(self, lower: ArrayLike) -> None:\n        \"\"\"This transform maps any value bijectively to the interval [lower, inf).\n\n        Args:\n            lower (ArrayLike): Lower bound of the interval.\n        \"\"\"\n        super().__init__()\n        self.lower = lower\n\n    def forward(self, x: ArrayLike) -> Array:\n        return jnp.log1p(save_exp(x)) + self.lower\n\n    def inverse(self, y: ArrayLike) -> Array:\n        return jnp.log(save_exp(y - self.lower) - 1.0)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SoftplusTransform.__init__","title":"__init__(lower)","text":"

This transform maps any value bijectively to the interval [lower, inf).

Parameters:

Name Type Description Default lower ArrayLike

Lower bound of the interval.

required Source code in jaxley/optimize/transforms.py
def __init__(self, lower: ArrayLike) -> None:\n    \"\"\"This transform maps any value bijectively to the interval [lower, inf).\n\n    Args:\n        lower (ArrayLike): Lower bound of the interval.\n    \"\"\"\n    super().__init__()\n    self.lower = lower\n
"},{"location":"reference/utils/","title":"Utils","text":""},{"location":"reference/utils/#jaxley.utils.cell_utils.build_radiuses_from_xyzr","title":"build_radiuses_from_xyzr(radius_fns, branch_indices, min_radius, ncomp)","text":"

Return the radiuses of branches given SWC file xyzr.

Returns an array of shape (num_branches, ncomp).

Parameters:

Name Type Description Default radius_fns List[Callable]

Functions which, given compartment locations return the radius.

required branch_indices List[int]

The indices of the branches for which to return the radiuses.

required min_radius Optional[float]

If passed, the radiuses are clipped to be at least as large.

required ncomp int

The number of compartments that every branch is discretized into.

required Source code in jaxley/utils/cell_utils.py
def build_radiuses_from_xyzr(\n    radius_fns: List[Callable],\n    branch_indices: List[int],\n    min_radius: Optional[float],\n    ncomp: int,\n) -> jnp.ndarray:\n    \"\"\"Return the radiuses of branches given SWC file xyzr.\n\n    Returns an array of shape `(num_branches, ncomp)`.\n\n    Args:\n        radius_fns: Functions which, given compartment locations return the radius.\n        branch_indices: The indices of the branches for which to return the radiuses.\n        min_radius: If passed, the radiuses are clipped to be at least as large.\n        ncomp: The number of compartments that every branch is discretized into.\n    \"\"\"\n    # Compartment locations are at the center of the internal nodes.\n    non_split = 1 / ncomp\n    range_ = np.linspace(non_split / 2, 1 - non_split / 2, ncomp)\n\n    # Build radiuses.\n    radiuses = np.asarray([radius_fns[b](range_) for b in branch_indices])\n    radiuses_each = radiuses.ravel(order=\"C\")\n    if min_radius is None:\n        assert np.all(\n            radiuses_each > 0.0\n        ), \"Radius 0.0 in SWC file. Set `read_swc(..., min_radius=...)`.\"\n    else:\n        radiuses_each[radiuses_each < min_radius] = min_radius\n\n    return radiuses_each\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_axial_conductances","title":"compute_axial_conductances(comp_edges, params)","text":"

Given comp_edges, radius, length, r_a, cm, compute the axial conductances.

Note that the resulting axial conductances will already by divided by the capacitance cm.

Source code in jaxley/utils/cell_utils.py
def compute_axial_conductances(\n    comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray]\n) -> jnp.ndarray:\n    \"\"\"Given `comp_edges`, radius, length, r_a, cm, compute the axial conductances.\n\n    Note that the resulting axial conductances will already by divided by the\n    capacitance `cm`.\n    \"\"\"\n    # `Compartment-to-compartment` (c2c) axial coupling conductances.\n    condition = comp_edges[\"type\"].to_numpy() == 0\n    source_comp_inds = np.asarray(comp_edges[condition][\"source\"].to_list())\n    sink_comp_inds = np.asarray(comp_edges[condition][\"sink\"].to_list())\n\n    if len(sink_comp_inds) > 0:\n        conds_c2c = (\n            vmap(compute_coupling_cond, in_axes=(0, 0, 0, 0, 0, 0))(\n                params[\"radius\"][sink_comp_inds],\n                params[\"radius\"][source_comp_inds],\n                params[\"axial_resistivity\"][sink_comp_inds],\n                params[\"axial_resistivity\"][source_comp_inds],\n                params[\"length\"][sink_comp_inds],\n                params[\"length\"][source_comp_inds],\n            )\n            / params[\"capacitance\"][sink_comp_inds]\n        )\n    else:\n        conds_c2c = jnp.asarray([])\n\n    # `branchpoint-to-compartment` (bp2c) axial coupling conductances.\n    condition = comp_edges[\"type\"].isin([1, 2])\n    sink_comp_inds = np.asarray(comp_edges[condition][\"sink\"].to_list())\n\n    if len(sink_comp_inds) > 0:\n        conds_bp2c = (\n            vmap(compute_coupling_cond_branchpoint, in_axes=(0, 0, 0))(\n                params[\"radius\"][sink_comp_inds],\n                params[\"axial_resistivity\"][sink_comp_inds],\n                params[\"length\"][sink_comp_inds],\n            )\n            / params[\"capacitance\"][sink_comp_inds]\n        )\n    else:\n        conds_bp2c = jnp.asarray([])\n\n    # `compartment-to-branchpoint` (c2bp) axial coupling conductances.\n    condition = comp_edges[\"type\"].isin([3, 4])\n    source_comp_inds = np.asarray(comp_edges[condition][\"source\"].to_list())\n\n    if len(source_comp_inds) > 0:\n        conds_c2bp = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n            params[\"radius\"][source_comp_inds],\n            params[\"axial_resistivity\"][source_comp_inds],\n            params[\"length\"][source_comp_inds],\n        )\n        # For numerical stability. These values are very small, but their scale\n        # does not matter.\n        conds_c2bp *= 1_000\n    else:\n        conds_c2bp = jnp.asarray([])\n\n    # All axial coupling conductances.\n    return jnp.concatenate([conds_c2c, conds_bp2c, conds_c2bp])\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_children_and_parents","title":"compute_children_and_parents(branch_edges)","text":"

Build indices used during `._init_morph_custom_spsolve().

Source code in jaxley/utils/cell_utils.py
def compute_children_and_parents(\n    branch_edges: pd.DataFrame,\n) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int]:\n    \"\"\"Build indices used during `._init_morph_custom_spsolve().\"\"\"\n    par_inds = branch_edges[\"parent_branch_index\"].to_numpy()\n    child_inds = branch_edges[\"child_branch_index\"].to_numpy()\n    child_belongs_to_branchpoint = remap_to_consecutive(par_inds)\n    par_inds = np.unique(par_inds)\n    return par_inds, child_inds, child_belongs_to_branchpoint\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_children_indices","title":"compute_children_indices(parents)","text":"

Return all children indices of every branch.

Example:

parents = [-1, 0, 0]\ncompute_children_indices(parents) -> [[1, 2], [], []]\n

Source code in jaxley/utils/cell_utils.py
def compute_children_indices(parents) -> List[jnp.ndarray]:\n    \"\"\"Return all children indices of every branch.\n\n    Example:\n    ```\n    parents = [-1, 0, 0]\n    compute_children_indices(parents) -> [[1, 2], [], []]\n    ```\n    \"\"\"\n    num_branches = len(parents)\n    child_indices = []\n    for b in range(num_branches):\n        child_indices.append(np.where(parents == b)[0])\n    return child_indices\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_coupling_cond","title":"compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2)","text":"

Return the coupling conductance between two compartments.

Equations taken from https://en.wikipedia.org/wiki/Compartmental_neuron_models.

radius: um r_a: ohm cm length_single_compartment: um coupling_conds: S * um / cm / um^2 = S / cm / um -> *10**7 -> mS / cm^2

Source code in jaxley/utils/cell_utils.py
def compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2):\n    \"\"\"Return the coupling conductance between two compartments.\n\n    Equations taken from `https://en.wikipedia.org/wiki/Compartmental_neuron_models`.\n\n    `radius`: um\n    `r_a`: ohm cm\n    `length_single_compartment`: um\n    `coupling_conds`: S * um / cm / um^2 = S / cm / um -> *10**7 -> mS / cm^2\n    \"\"\"\n    # Multiply by 10**7 to convert (S / cm / um) -> (mS / cm^2).\n    return rad1 * rad2**2 / (r_a1 * rad2**2 * l1 + r_a2 * rad1**2 * l2) / l1 * 10**7\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_coupling_cond_branchpoint","title":"compute_coupling_cond_branchpoint(rad, r_a, l)","text":"

Return the coupling conductance between one compartment and a comp with l=0.

From https://en.wikipedia.org/wiki/Compartmental_neuron_models

If one compartment has l=0.0 then the equations simplify.

R_long = \\sum_i r_a * L_i/2 / crosssection_i

with crosssection = pi * r**2

For a single compartment with L>0, this turns into: R_long = r_a * L/2 / crosssection

Then, g_long = crosssection * 2 / L / r_a

Then, the effective conductance is g_long / zylinder_area. So: g = pi * r**2 * 2 / L / r_a / 2 / pi / r / L g = r / r_a / L**2

Source code in jaxley/utils/cell_utils.py
def compute_coupling_cond_branchpoint(rad, r_a, l):\n    r\"\"\"Return the coupling conductance between one compartment and a comp with l=0.\n\n    From https://en.wikipedia.org/wiki/Compartmental_neuron_models\n\n    If one compartment has l=0.0 then the equations simplify.\n\n    R_long = \\sum_i r_a * L_i/2 / crosssection_i\n\n    with crosssection = pi * r**2\n\n    For a single compartment with L>0, this turns into:\n    R_long = r_a * L/2 / crosssection\n\n    Then, g_long = crosssection * 2 / L / r_a\n\n    Then, the effective conductance is g_long / zylinder_area. So:\n    g = pi * r**2 * 2 / L / r_a / 2 / pi / r / L\n    g = r / r_a / L**2\n    \"\"\"\n    return rad / r_a / l**2 * 10**7  # Convert (S / cm / um) -> (mS / cm^2)\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_impact_on_node","title":"compute_impact_on_node(rad, r_a, l)","text":"

Compute the weight with which a compartment influences its node.

In order to satisfy Kirchhoffs current law, the current at a branch point must be proportional to the crosssection of the compartment. We only require proportionality here because the branch point equation reads: g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0

Because R_long = r_a * L/2 / crosssection, we get g_long = crosssection * 2 / L / r_a \\propto rad**2 / L / r_a

This equation can be multiplied by any constant.

Source code in jaxley/utils/cell_utils.py
def compute_impact_on_node(rad, r_a, l):\n    r\"\"\"Compute the weight with which a compartment influences its node.\n\n    In order to satisfy Kirchhoffs current law, the current at a branch point must be\n    proportional to the crosssection of the compartment. We only require proportionality\n    here because the branch point equation reads:\n    `g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0`\n\n    Because R_long = r_a * L/2 / crosssection, we get\n    g_long = crosssection * 2 / L / r_a \\propto rad**2 / L / r_a\n\n    This equation can be multiplied by any constant.\"\"\"\n    return rad**2 / r_a / l\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_morphology_indices_in_levels","title":"compute_morphology_indices_in_levels(num_branchpoints, child_belongs_to_branchpoint, par_inds, child_inds)","text":"

Return (row, col) to build the sparse matrix defining the voltage eqs.

This is run at init, not during runtime.

Source code in jaxley/utils/cell_utils.py
def compute_morphology_indices_in_levels(\n    num_branchpoints,\n    child_belongs_to_branchpoint,\n    par_inds,\n    child_inds,\n):\n    \"\"\"Return (row, col) to build the sparse matrix defining the voltage eqs.\n\n    This is run at `init`, not during runtime.\n    \"\"\"\n    branchpoint_inds_parents = jnp.arange(num_branchpoints)\n    branchpoint_inds_children = child_belongs_to_branchpoint\n    branch_inds_parents = par_inds\n    branch_inds_children = child_inds\n\n    children = jnp.stack([branch_inds_children, branchpoint_inds_children])\n    parents = jnp.stack([branch_inds_parents, branchpoint_inds_parents])\n\n    return {\"children\": children.T, \"parents\": parents.T}\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.convert_point_process_to_distributed","title":"convert_point_process_to_distributed(current, radius, length)","text":"

Convert current point process (nA) to distributed current (uA/cm2).

This function gets called for synapses and for external stimuli.

Parameters:

Name Type Description Default current ndarray

Current in nA.

required radius ndarray

Compartment radius in um.

required length ndarray

Compartment length in um.

required Return

Current in uA/cm2.

Source code in jaxley/utils/cell_utils.py
def convert_point_process_to_distributed(\n    current: jnp.ndarray, radius: jnp.ndarray, length: jnp.ndarray\n) -> jnp.ndarray:\n    \"\"\"Convert current point process (nA) to distributed current (uA/cm2).\n\n    This function gets called for synapses and for external stimuli.\n\n    Args:\n        current: Current in `nA`.\n        radius: Compartment radius in `um`.\n        length: Compartment length in `um`.\n\n    Return:\n        Current in `uA/cm2`.\n    \"\"\"\n    area = 2 * pi * radius * length\n    current /= area  # nA / um^2\n    return current * 100_000  # Convert (nA / um^2) to (uA / cm^2)\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.equal_segments","title":"equal_segments(branch_property, ncomp_per_branch)","text":"

Generates segments where some property is the same in each segment.

Parameters:

Name Type Description Default branch_property list

List of values of the property in each branch. Should have len(branch_property) == num_branches.

required Source code in jaxley/utils/cell_utils.py
def equal_segments(branch_property: list, ncomp_per_branch: int):\n    \"\"\"Generates segments where some property is the same in each segment.\n\n    Args:\n        branch_property: List of values of the property in each branch. Should have\n            `len(branch_property) == num_branches`.\n    \"\"\"\n    assert isinstance(branch_property, list), \"branch_property must be a list.\"\n    return jnp.asarray([branch_property] * ncomp_per_branch).T\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.get_num_neighbours","title":"get_num_neighbours(num_children, ncomp_per_branch, num_branches)","text":"

Number of neighbours of each compartment.

Source code in jaxley/utils/cell_utils.py
def get_num_neighbours(\n    num_children: jnp.ndarray,\n    ncomp_per_branch: int,\n    num_branches: int,\n):\n    \"\"\"\n    Number of neighbours of each compartment.\n    \"\"\"\n    num_neighbours = 2 * jnp.ones((num_branches * ncomp_per_branch))\n    num_neighbours = num_neighbours.at[ncomp_per_branch - 1].set(1.0)\n    num_neighbours = num_neighbours.at[jnp.arange(num_branches) * ncomp_per_branch].set(\n        num_children + 1.0\n    )\n    return num_neighbours\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.group_and_sum","title":"group_and_sum(values_to_sum, inds_to_group_by, num_branchpoints)","text":"

Group values by whether they have the same integer and sum values within group.

This is used to construct the last diagonals at the branch points.

Written by ChatGPT.

Source code in jaxley/utils/cell_utils.py
def group_and_sum(\n    values_to_sum: jnp.ndarray, inds_to_group_by: jnp.ndarray, num_branchpoints: int\n) -> jnp.ndarray:\n    \"\"\"Group values by whether they have the same integer and sum values within group.\n\n    This is used to construct the last diagonals at the branch points.\n\n    Written by ChatGPT.\n    \"\"\"\n    # Initialize an array to hold the sum of each group\n    group_sums = jnp.zeros(num_branchpoints)\n\n    # `.at[inds]` requires that `inds` is not empty, so we need an if-case here.\n    # `len(inds) == 0` is the case for branches and compartments.\n    if num_branchpoints > 0:\n        group_sums = group_sums.at[inds_to_group_by].add(values_to_sum)\n\n    return group_sums\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.interpolate_xyzr","title":"interpolate_xyzr(loc, coords)","text":"

Perform a linear interpolation between xyz-coordinates.

Parameters:

Name Type Description Default loc float

The location in [0,1] along the branch.

required coords ndarray

Array containing the reconstructed xyzr points of the branch.

required Return

Interpolated xyz coordinate at loc, shape `(3,).

Source code in jaxley/utils/cell_utils.py
def interpolate_xyzr(loc: float, coords: np.ndarray):\n    \"\"\"Perform a linear interpolation between xyz-coordinates.\n\n    Args:\n        loc: The location in [0,1] along the branch.\n        coords: Array containing the reconstructed xyzr points of the branch.\n\n    Return:\n        Interpolated xyz coordinate at `loc`, shape `(3,).\n    \"\"\"\n    dl = np.sqrt(np.sum(np.diff(coords[:, :3], axis=0) ** 2, axis=1))\n    pathlens = np.insert(np.cumsum(dl), 0, 0)  # cummulative length of sections\n    norm_pathlens = pathlens / np.maximum(1e-8, pathlens[-1])  # norm lengths to [0,1].\n\n    return v_interp(loc, norm_pathlens, coords)\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.linear_segments","title":"linear_segments(initial_val, endpoint_vals, parents, ncomp_per_branch)","text":"

Generates segments where some property is linearly interpolated.

Parameters:

Name Type Description Default initial_val float

The value at the tip of the soma.

required endpoint_vals list

The value at the endpoints of each branch.

required Source code in jaxley/utils/cell_utils.py
def linear_segments(\n    initial_val: float, endpoint_vals: list, parents: jnp.ndarray, ncomp_per_branch: int\n):\n    \"\"\"Generates segments where some property is linearly interpolated.\n\n    Args:\n        initial_val: The value at the tip of the soma.\n        endpoint_vals: The value at the endpoints of each branch.\n    \"\"\"\n    branch_property = endpoint_vals + [initial_val]\n    num_branches = len(parents)\n    # Compute radiuses by linear interpolation.\n    endpoint_radiuses = jnp.asarray(branch_property)\n\n    def compute_rad(branch_ind, loc):\n        start = endpoint_radiuses[parents[branch_ind]]\n        end = endpoint_radiuses[branch_ind]\n        return (end - start) * loc + start\n\n    branch_inds_of_each_comp = jnp.tile(jnp.arange(num_branches), ncomp_per_branch)\n    locs_of_each_comp = jnp.linspace(1, 0, ncomp_per_branch).repeat(num_branches)\n    rad_of_each_comp = compute_rad(branch_inds_of_each_comp, locs_of_each_comp)\n\n    return jnp.reshape(rad_of_each_comp, (ncomp_per_branch, num_branches)).T\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.loc_of_index","title":"loc_of_index(global_comp_index, global_branch_index, ncomp_per_branch)","text":"

Return location corresponding to global compartment index.

Source code in jaxley/utils/cell_utils.py
def loc_of_index(global_comp_index, global_branch_index, ncomp_per_branch):\n    \"\"\"Return location corresponding to global compartment index.\"\"\"\n    cumsum_ncomp = cumsum_leading_zero(ncomp_per_branch)\n    index = global_comp_index - cumsum_ncomp[global_branch_index]\n    ncomp = ncomp_per_branch[global_branch_index]\n    return (0.5 + index) / ncomp\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.local_index_of_loc","title":"local_index_of_loc(loc, global_branch_ind, ncomp_per_branch)","text":"

Returns the local index of a comp given a loc [0, 1] and the index of a branch.

This is used because we specify locations such as synapses as a value between 0 and 1. We have to convert this onto a discrete segment here.

Parameters:

Name Type Description Default branch_ind

Index of the branch.

required loc float

Location (in [0, 1]) along that branch.

required ncomp_per_branch int

Number of segments of each branch.

required

Returns:

Type Description int

The local index of the compartment.

Source code in jaxley/utils/cell_utils.py
def local_index_of_loc(\n    loc: float, global_branch_ind: int, ncomp_per_branch: int\n) -> int:\n    \"\"\"Returns the local index of a comp given a loc [0, 1] and the index of a branch.\n\n    This is used because we specify locations such as synapses as a value between 0 and\n    1. We have to convert this onto a discrete segment here.\n\n    Args:\n        branch_ind: Index of the branch.\n        loc: Location (in [0, 1]) along that branch.\n        ncomp_per_branch: Number of segments of each branch.\n\n    Returns:\n        The local index of the compartment.\n    \"\"\"\n    ncomp = ncomp_per_branch[global_branch_ind]  # only for convenience.\n    possible_locs = np.linspace(0.5 / ncomp, 1 - 0.5 / ncomp, ncomp)\n    ind_along_branch = np.argmin(np.abs(possible_locs - loc))\n    return ind_along_branch\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.merge_cells","title":"merge_cells(cumsum_num_branches, cumsum_num_branchpoints, arrs, exclude_first=True)","text":"

Build full list of which branches are solved in which iteration.

From the branching pattern of single cells, this \u201cmerges\u201d them into a single ordering of branches.

Parameters:

Name Type Description Default cumsum_num_branches List[int]

cumulative number of branches. E.g., for three cells with 10, 15, and 5 branches respectively, this will should be a list containing [0, 10, 25, 30].

required arrs List[List[ndarray]]

A list of a list of arrays that should be merged.

required exclude_first bool

If True, the first element of each list in arrs will remain unchanged. Useful if a -1 (which indicates \u201cno parent\u201d) entry should not be changed.

True

Returns:

Type Description ndarray

A list of arrays which contain the branch indices that are computed at each

ndarray

level (i.e., iteration).

Source code in jaxley/utils/cell_utils.py
def merge_cells(\n    cumsum_num_branches: List[int],\n    cumsum_num_branchpoints: List[int],\n    arrs: List[List[np.ndarray]],\n    exclude_first: bool = True,\n) -> np.ndarray:\n    \"\"\"\n    Build full list of which branches are solved in which iteration.\n\n    From the branching pattern of single cells, this \"merges\" them into a single\n    ordering of branches.\n\n    Args:\n        cumsum_num_branches: cumulative number of branches. E.g., for three cells with\n            10, 15, and 5 branches respectively, this will should be a list containing\n            `[0, 10, 25, 30]`.\n        arrs: A list of a list of arrays that should be merged.\n        exclude_first: If `True`, the first element of each list in `arrs` will remain\n            unchanged. Useful if a `-1` (which indicates \"no parent\") entry should not\n            be changed.\n\n    Returns:\n        A list of arrays which contain the branch indices that are computed at each\n        level (i.e., iteration).\n    \"\"\"\n    ps = []\n    for i, att in enumerate(arrs):\n        p = att\n        if exclude_first:\n            raise NotImplementedError\n            p = [p[0]] + [p_in_level + cumsum_num_branches[i] for p_in_level in p[1:]]\n        else:\n            p = [\n                p_in_level\n                + np.asarray([cumsum_num_branches[i], cumsum_num_branchpoints[i]])\n                for p_in_level in p\n            ]\n        ps.append(p)\n\n    max_len = max([len(att) for att in arrs])\n    combined_parents_in_level = []\n    for i in range(max_len):\n        current_ps = []\n        for p in ps:\n            if len(p) > i:\n                current_ps.append(p[i])\n        combined_parents_in_level.append(np.concatenate(current_ps))\n\n    return combined_parents_in_level\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.params_to_pstate","title":"params_to_pstate(params, indices_set_by_trainables)","text":"

Make outputs get_parameters() conform with outputs of .data_set().

make_trainable() followed by params=get_parameters() does not return indices because these indices would also be differentiated by jax.grad (as soon as the params are passed to def simulate(params). Therefore, in jx.integrate, we run the function to add indices to the dict. The outputs of params_to_pstate are of the same shape as the outputs of .data_set().

Source code in jaxley/utils/cell_utils.py
def params_to_pstate(\n    params: List[Dict[str, jnp.ndarray]],\n    indices_set_by_trainables: List[jnp.ndarray],\n):\n    \"\"\"Make outputs `get_parameters()` conform with outputs of `.data_set()`.\n\n    `make_trainable()` followed by `params=get_parameters()` does not return indices\n    because these indices would also be differentiated by `jax.grad` (as soon as\n    the `params` are passed to `def simulate(params)`. Therefore, in `jx.integrate`,\n    we run the function to add indices to the dict. The outputs of `params_to_pstate`\n    are of the same shape as the outputs of `.data_set()`.\"\"\"\n    return [\n        {\"key\": list(p.keys())[0], \"val\": list(p.values())[0], \"indices\": i}\n        for p, i in zip(params, indices_set_by_trainables)\n    ]\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.query_channel_states_and_params","title":"query_channel_states_and_params(d, keys, idcs)","text":"

Get dict with subset of keys and values from d.

This is used to restrict a dict where every item contains all states to only the ones that are relevant for the channel. E.g.

states = {'eCa': Array([ 0., 0., nan]}

will be states = {'eCa': Array([ 0., 0.]}

Only loops over necessary keys, as opposed to looping over d.items().

Source code in jaxley/utils/cell_utils.py
def query_channel_states_and_params(d, keys, idcs):\n    \"\"\"Get dict with subset of keys and values from d.\n\n    This is used to restrict a dict where every item contains __all__ states to only\n    the ones that are relevant for the channel. E.g.\n\n    ```states = {'eCa': Array([ 0.,  0., nan]}```\n\n    will be\n    ```states = {'eCa': Array([ 0.,  0.]}```\n\n    Only loops over necessary keys, as opposed to looping over `d.items()`.\"\"\"\n    return dict(zip(keys, (v[idcs] for v in map(d.get, keys))))\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.remap_to_consecutive","title":"remap_to_consecutive(arr)","text":"

Maps an array of integers to an array of consecutive integers.

E.g. [0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]

Source code in jaxley/utils/cell_utils.py
def remap_to_consecutive(arr):\n    \"\"\"Maps an array of integers to an array of consecutive integers.\n\n    E.g. `[0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]`\n    \"\"\"\n    _, inverse_indices = jnp.unique(arr, return_inverse=True)\n    return inverse_indices\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.compute_rotation_matrix","title":"compute_rotation_matrix(axis, angle)","text":"

Return the rotation matrix associated with counterclockwise rotation about the given axis by the given angle.

Can be used to rotate a coordinate vector by multiplying it with the rotation matrix.

Parameters:

Name Type Description Default axis ndarray

The axis of rotation.

required angle float

The angle of rotation in radians.

required

Returns:

Type Description ndarray

A 3x3 rotation matrix.

Source code in jaxley/utils/plot_utils.py
def compute_rotation_matrix(axis: ndarray, angle: float) -> ndarray:\n    \"\"\"\n    Return the rotation matrix associated with counterclockwise rotation about\n    the given axis by the given angle.\n\n    Can be used to rotate a coordinate vector by multiplying it with the rotation\n    matrix.\n\n    Args:\n        axis: The axis of rotation.\n        angle: The angle of rotation in radians.\n\n    Returns:\n        A 3x3 rotation matrix.\n    \"\"\"\n    axis = axis / np.sqrt(np.dot(axis, axis))\n    a = np.cos(angle / 2.0)\n    b, c, d = -axis * np.sin(angle / 2.0)\n    aa, bb, cc, dd = a * a, b * b, c * c, d * d\n    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d\n    return np.array(\n        [\n            [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],\n            [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],\n            [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc],\n        ]\n    )\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.create_cone_frustum_mesh","title":"create_cone_frustum_mesh(length, radius_bottom, radius_top, bottom_dome=False, top_dome=False, resolution=100)","text":"

Generates mesh points for a cone frustum, with optional domes at either end.

This is used to render the traced morphology in 3D (and to project it to 2D) as part of plot_morph. Sections between two traced coordinates with two different radii can be represented by a cone frustum. Additionally, the ends of the frustum can be capped with hemispheres to ensure that two neighbouring frustums are connected smoothly (like ball joints).

Parameters:

Name Type Description Default length float

The length of the frustum.

required radius_bottom float

The radius of the bottom of the frustum.

required radius_top float

The radius of the top of the frustum.

required bottom_dome bool

If True, a dome is added to the bottom of the frustum. The dome is a hemisphere with radius radius_bottom.

False top_dome bool

If True, a dome is added to the top of the frustum. The dome is a hemisphere with radius radius_top.

False resolution int

defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.

100

Returns:

Type Description ndarray

An array of mesh points.

Source code in jaxley/utils/plot_utils.py
def create_cone_frustum_mesh(\n    length: float,\n    radius_bottom: float,\n    radius_top: float,\n    bottom_dome: bool = False,\n    top_dome: bool = False,\n    resolution: int = 100,\n) -> ndarray:\n    \"\"\"Generates mesh points for a cone frustum, with optional domes at either end.\n\n    This is used to render the traced morphology in 3D (and to project it to 2D)\n    as part of `plot_morph`. Sections between two traced coordinates with two\n    different radii can be represented by a cone frustum. Additionally, the ends\n    of the frustum can be capped with hemispheres to ensure that two neighbouring\n    frustums are connected smoothly (like ball joints).\n\n    Args:\n        length: The length of the frustum.\n        radius_bottom: The radius of the bottom of the frustum.\n        radius_top: The radius of the top of the frustum.\n        bottom_dome: If True, a dome is added to the bottom of the frustum.\n            The dome is a hemisphere with radius `radius_bottom`.\n        top_dome: If True, a dome is added to the top of the frustum.\n            The dome is a hemisphere with radius `radius_top`.\n        resolution: defines the resolution of the mesh.\n            If too low (typically <10), can result in errors.\n            Useful too have a simpler mesh for plotting.\n\n    Returns:\n        An array of mesh points.\n    \"\"\"\n\n    t = np.linspace(0, 2 * np.pi, resolution)\n\n    # Determine the total height including domes\n    total_height = length\n    total_height += radius_bottom if bottom_dome else 0\n    total_height += radius_top if top_dome else 0\n\n    z = np.linspace(0, total_height, resolution)\n    t_grid, z_coords = np.meshgrid(t, z)\n\n    # Initialize arrays\n    x_coords = np.zeros_like(t_grid)\n    y_coords = np.zeros_like(t_grid)\n    r_coords = np.zeros_like(t_grid)\n\n    # Bottom hemisphere\n    if bottom_dome:\n        dome_mask = z_coords < radius_bottom\n        arg = 1 - z_coords[dome_mask] / radius_bottom\n        arg[np.isclose(arg, 1, atol=1e-6, rtol=1e-6)] = 1\n        arg[np.isclose(arg, -1, atol=1e-6, rtol=1e-6)] = -1\n        phi = np.arccos(1 - z_coords[dome_mask] / radius_bottom)\n        r_coords[dome_mask] = radius_bottom * np.sin(phi)\n        z_coords[dome_mask] = z_coords[dome_mask]\n\n    # Frustum\n    frustum_start = radius_bottom if bottom_dome else 0\n    frustum_end = total_height - (radius_top if top_dome else 0)\n    frustum_mask = (z_coords >= frustum_start) & (z_coords <= frustum_end)\n    z_frustum = z_coords[frustum_mask] - frustum_start\n    r_coords[frustum_mask] = radius_bottom + (radius_top - radius_bottom) * (\n        z_frustum / length\n    )\n\n    # Top hemisphere\n    if top_dome:\n        dome_mask = z_coords > (total_height - radius_top)\n        arg = (z_coords[dome_mask] - (total_height - radius_top)) / radius_top\n        arg[np.isclose(arg, 1, atol=1e-6, rtol=1e-6)] = 1\n        arg[np.isclose(arg, -1, atol=1e-6, rtol=1e-6)] = -1\n        phi = np.arccos(arg)\n        r_coords[dome_mask] = radius_top * np.sin(phi)\n\n    x_coords = r_coords * np.cos(t_grid)\n    y_coords = r_coords * np.sin(t_grid)\n\n    return np.stack([x_coords, y_coords, z_coords])\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.create_cylinder_mesh","title":"create_cylinder_mesh(length, radius, resolution=100)","text":"

Generates mesh points for a cylinder.

This is used to render cylindrical compartments in 3D (and to project it to 2D) as part of plot_comps.

Parameters:

Name Type Description Default length float

The length of the cylinder.

required radius float

The radius of the cylinder.

required resolution int

defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.

100

Returns:

Type Description ndarray

An array of mesh points.

Source code in jaxley/utils/plot_utils.py
def create_cylinder_mesh(\n    length: float, radius: float, resolution: int = 100\n) -> ndarray:\n    \"\"\"Generates mesh points for a cylinder.\n\n    This is used to render cylindrical compartments in 3D (and to project it to 2D)\n    as part of `plot_comps`.\n\n    Args:\n        length: The length of the cylinder.\n        radius: The radius of the cylinder.\n        resolution: defines the resolution of the mesh.\n            If too low (typically <10), can result in errors.\n            Useful too have a simpler mesh for plotting.\n\n    Returns:\n        An array of mesh points.\n    \"\"\"\n    # Define cylinder\n    t = np.linspace(0, 2 * np.pi, resolution)\n    z_coords = np.linspace(-length / 2, length / 2, resolution)\n    t_grid, z_coords = np.meshgrid(t, z_coords)\n\n    x_coords = radius * np.cos(t_grid)\n    y_coords = radius * np.sin(t_grid)\n    return np.stack([x_coords, y_coords, z_coords])\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.create_sphere_mesh","title":"create_sphere_mesh(radius, resolution=100)","text":"

Generates mesh points for a sphere.

This is used to render spherical compartments in 3D (and to project it to 2D) as part of plot_comps.

Parameters:

Name Type Description Default radius float

The radius of the sphere.

required resolution int

defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.

100

Returns:

Type Description ndarray

An array of mesh points.

Source code in jaxley/utils/plot_utils.py
def create_sphere_mesh(radius: float, resolution: int = 100) -> np.ndarray:\n    \"\"\"Generates mesh points for a sphere.\n\n    This is used to render spherical compartments in 3D (and to project it to 2D)\n    as part of `plot_comps`.\n\n    Args:\n        radius: The radius of the sphere.\n        resolution: defines the resolution of the mesh.\n            If too low (typically <10), can result in errors.\n            Useful too have a simpler mesh for plotting.\n\n    Returns:\n        An array of mesh points.\n    \"\"\"\n    phi = np.linspace(0, np.pi, resolution)\n    theta = np.linspace(0, 2 * np.pi, resolution)\n\n    # Create a 2D meshgrid for phi and theta\n    phi_coords, theta_coords = np.meshgrid(phi, theta)\n\n    # Convert spherical coordinates to Cartesian coordinates\n    x_coords = radius * np.sin(phi_coords) * np.cos(theta_coords)\n    y_coords = radius * np.sin(phi_coords) * np.sin(theta_coords)\n    z_coords = radius * np.cos(phi_coords)\n\n    return np.stack([x_coords, y_coords, z_coords])\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.extract_outline","title":"extract_outline(points)","text":"

Get the outline of a 2D/3D shape.

Extracts the subset of points which form the convex hull, i.e. the outline of the input points.

Parameters:

Name Type Description Default points ndarray

An array of points / corrdinates.

required

Returns:

Type Description ndarray

An array of points which form the convex hull.

Source code in jaxley/utils/plot_utils.py
def extract_outline(points: ndarray) -> ndarray:\n    \"\"\"Get the outline of a 2D/3D shape.\n\n    Extracts the subset of points which form the convex hull, i.e. the outline of\n    the input points.\n\n    Args:\n        points: An array of points / corrdinates.\n\n    Returns:\n        An array of points which form the convex hull.\n    \"\"\"\n    hull = ConvexHull(points)\n    hull_points = points[hull.vertices]\n    return hull_points\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_comps","title":"plot_comps(module_or_view, dims=(0, 1), color='k', ax=None, true_comp_length=True, resolution=100, **kwargs)","text":"

Plot compartmentalized neural morphology.

Plots the projection of the cylindrical compartments.

Parameters:

Name Type Description Default module_or_view Union[Module, View]

The module or view to plot.

required dims Tuple[int]

The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.

(0, 1) color str

The color for all compartments

'k' ax Optional[Axes]

The matplotlib axis to plot on.

None true_comp_length bool

If True, the length of the compartment is used, i.e. the length of the traced neurite. This means for zig-zagging neurites the cylinders will be longer than the straight-line distance between the start and end point of the neurite. This can lead to overlapping and miss-aligned cylinders. Setting this False will use the straight-line distance instead for nicer plots.

True resolution int

defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.

100 kwargs

The plot kwargs for plt.fill.

{}

Returns:

Type Description Axes

Plot of the compartmentalized morphology.

Source code in jaxley/utils/plot_utils.py
def plot_comps(\n    module_or_view: Union[\"jx.Module\", \"jx.View\"],\n    dims: Tuple[int] = (0, 1),\n    color: str = \"k\",\n    ax: Optional[Axes] = None,\n    true_comp_length: bool = True,\n    resolution: int = 100,\n    **kwargs,\n) -> Axes:\n    \"\"\"Plot compartmentalized neural morphology.\n\n    Plots the projection of the cylindrical compartments.\n\n    Args:\n        module_or_view: The module or view to plot.\n        dims: The dimensions to plot / to project the cylinder onto,\n            i.e. [0,1] xy-plane or [0,1,2] for 3D.\n        color: The color for all compartments\n        ax: The matplotlib axis to plot on.\n        true_comp_length: If True, the length of the compartment is used, i.e. the\n            length of the traced neurite. This means for zig-zagging neurites the\n            cylinders will be longer than the straight-line distance between the\n            start and end point of the neurite. This can lead to overlapping and\n            miss-aligned cylinders. Setting this False will use the straight-line\n            distance instead for nicer plots.\n        resolution: defines the resolution of the mesh.\n            If too low (typically <10), can result in errors.\n            Useful too have a simpler mesh for plotting.\n        kwargs: The plot kwargs for plt.fill.\n\n    Returns:\n        Plot of the compartmentalized morphology.\n    \"\"\"\n    if ax is None:\n        fig = plt.figure(figsize=(3, 3))\n        ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n    assert not np.any(\n        np.isnan(module_or_view.xyzr[0][:, :3])\n    ), \"missing xyz coordinates.\"\n    if \"x\" not in module_or_view.nodes.columns:\n        module_or_view.compute_compartment_centers()\n\n    for idx, xyzr in zip(module_or_view._branches_in_view, module_or_view.xyzr):\n        locs = xyzr[:, :3]\n        if locs.shape[0] == 1:  # assume spherical comp\n            radius = xyzr[:, -1]\n            center = xyzr[0, :3]\n            if len(dims) == 3:\n                xyz = create_sphere_mesh(radius, resolution)\n                ax = plot_mesh(\n                    xyz,\n                    np.array([0, 0, 1]),\n                    center,\n                    np.array(dims),\n                    ax,\n                    color=color,\n                    **kwargs,\n                )\n            else:\n                ax.add_artist(plt.Circle(locs[0, dims], radius, color=color))\n        else:\n            lens = np.sqrt(np.nansum(np.diff(locs, axis=0) ** 2, axis=1))\n            lens = np.cumsum([0] + lens.tolist())\n            comp_ends = v_interp(\n                np.linspace(0, lens[-1], module_or_view.ncomp + 1), lens, locs\n            ).T\n            axes = np.diff(comp_ends, axis=0)\n            cylinder_lens = np.sqrt(np.sum(axes**2, axis=1))\n\n            branch_df = module_or_view.nodes[\n                module_or_view.nodes[\"global_branch_index\"] == idx\n            ]\n            for l, axis, (i, comp) in zip(cylinder_lens, axes, branch_df.iterrows()):\n                center = comp[[\"x\", \"y\", \"z\"]]\n                radius = comp[\"radius\"]\n                length = comp[\"length\"] if true_comp_length else l\n                xyz = create_cylinder_mesh(length, radius, resolution)\n                ax = plot_mesh(\n                    xyz,\n                    axis,\n                    center,\n                    np.array(dims),\n                    ax,\n                    color=color,\n                    **kwargs,\n                )\n    return ax\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_graph","title":"plot_graph(xyzr, dims=(0, 1), color='k', ax=None, type='line', **kwargs)","text":"

Plot morphology.

Parameters:

Name Type Description Default xyzr ndarray

The coordinates of the morphology.

required dims Tuple[int]

Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two or three of them.

(0, 1) color str

The color for all branches.

'k' ax Optional[Axes]

The matplotlib axis to plot on.

None type str

Either line or scatter.

'line' kwargs

The plot kwargs for plt.plot or plt.scatter.

{} Source code in jaxley/utils/plot_utils.py
def plot_graph(\n    xyzr: ndarray,\n    dims: Tuple[int] = (0, 1),\n    color: str = \"k\",\n    ax: Optional[Axes] = None,\n    type: str = \"line\",\n    **kwargs,\n) -> Axes:\n    \"\"\"Plot morphology.\n\n    Args:\n        xyzr: The coordinates of the morphology.\n        dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n            two or three of them.\n        color: The color for all branches.\n        ax: The matplotlib axis to plot on.\n        type: Either `line` or `scatter`.\n        kwargs: The plot kwargs for plt.plot or plt.scatter.\n    \"\"\"\n\n    if ax is None:\n        fig = plt.figure(figsize=(3, 3))\n        ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n    for coords_of_branch in xyzr:\n        points = coords_of_branch[:, dims].T\n\n        if \"line\" in type.lower():\n            _ = ax.plot(*points, color=color, **kwargs)\n        elif \"scatter\" in type.lower():\n            _ = ax.scatter(*points, color=color, **kwargs)\n        else:\n            raise NotImplementedError\n\n    return ax\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_mesh","title":"plot_mesh(mesh_points, orientation, center, dims, ax=None, **kwargs)","text":"

Plot the 2D projection of a volume mesh on a cardinal plane.

Project the projection of a cylinder that is oriented in 3D space. - Create cylinder mesh - rotate cylinder mesh to orient it lengthwise along a given orientation vector. - move its center - project onto plane - compute outline of projected mesh. - fill area inside the outline

Parameters:

Name Type Description Default mesh_points ndarray

coordinates of the xyz mesh that define the volume

required orientation ndarray

orientation vector. The cylinder will be oriented along this vector.

required center ndarray

The x,y,z coordinates of the center of the cylinder.

required dims Tuple[int]

The dimensions to plot / to project the cylinder onto,

required ax Axes

The matplotlib axis to plot on.

None

Returns:

Type Description Axes

Plot of the cylinder projection.

Source code in jaxley/utils/plot_utils.py
def plot_mesh(\n    mesh_points: ndarray,\n    orientation: ndarray,\n    center: ndarray,\n    dims: Tuple[int],\n    ax: Axes = None,\n    **kwargs,\n) -> Axes:\n    \"\"\"Plot the 2D projection of a volume mesh on a cardinal plane.\n\n    Project the projection of a cylinder that is oriented in 3D space.\n    - Create cylinder mesh\n    - rotate cylinder mesh to orient it lengthwise along a given orientation vector.\n    - move its center\n    - project onto plane\n    - compute outline of projected mesh.\n    - fill area inside the outline\n\n    Args:\n        mesh_points: coordinates of the xyz mesh that define the volume\n        orientation: orientation vector. The cylinder will be oriented along this vector.\n        center: The x,y,z coordinates of the center of the cylinder.\n        dims: The dimensions to plot / to project the cylinder onto,\n        i.e. [0,1] xy-plane or [0,1,2] for 3D.\n        ax: The matplotlib axis to plot on.\n\n    Returns:\n        Plot of the cylinder projection.\n    \"\"\"\n    if ax is None:\n        fig = plt.figure(figsize=(3, 3))\n        ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n    # Normalize axis vector\n    orientation = np.array(orientation)\n    orientation = orientation / np.linalg.norm(orientation)\n\n    # Create a rotation matrix to align the cylinder with the given axis\n    z_axis = np.array([0, 0, 1])\n    rotation_axis = np.cross(z_axis, orientation)\n    rotation_angle = np.arccos(np.dot(z_axis, orientation))\n\n    if np.allclose(rotation_axis, 0):\n        rotation_matrix = np.eye(3)\n    else:\n        rotation_matrix = compute_rotation_matrix(rotation_axis, rotation_angle)\n\n    # Rotate mesh\n    x_mesh, y_mesh, z_mesh = mesh_points\n    rotated_mesh_points = np.dot(\n        rotation_matrix,\n        np.array([x_mesh.flatten(), y_mesh.flatten(), z_mesh.flatten()]),\n    )\n    rotated_mesh_points = rotated_mesh_points.reshape(3, -1)\n\n    # project onto plane and move\n    rotated_mesh_points = rotated_mesh_points[dims]\n    rotated_mesh_points += np.array(center)[dims, np.newaxis]\n\n    if len(dims) < 3:\n        # get outline of cylinder mesh\n        mesh_outline = extract_outline(rotated_mesh_points.T).T\n        ax.fill(*mesh_outline.reshape(mesh_outline.shape[0], -1), **kwargs)\n    else:\n        # plot 3d mesh\n        ax.plot_surface(*rotated_mesh_points.reshape(*mesh_points.shape), **kwargs)\n    return ax\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_morph","title":"plot_morph(module_or_view, dims=(0, 1), color='k', ax=None, resolution=100, **kwargs)","text":"

Plot the detailed morphology.

Plots the traced morphology it was traced. That means at every point that was traced a disc of radius r is plotted. The outline of the discs are then connected to form the morphology. This means every trace segement can be represented by a cone frustum. To prevent breaks in the morphology, each segement is connected with a ball joint.

Parameters:

Name Type Description Default module_or_view Union[Module, View]

The module or view to plot.

required dims Tuple[int]

The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.

(0, 1) color str

The color for all branches

'k' ax Optional[Axes]

The matplotlib axis to plot on.

None kwargs

The plot kwargs for plt.fill.

{} resolution int

defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.

100

Returns:

Type Description Axes

Plot of the detailed morphology.

Source code in jaxley/utils/plot_utils.py
def plot_morph(\n    module_or_view: Union[\"jx.Module\", \"jx.View\"],\n    dims: Tuple[int] = (0, 1),\n    color: str = \"k\",\n    ax: Optional[Axes] = None,\n    resolution: int = 100,\n    **kwargs,\n) -> Axes:\n    \"\"\"Plot the detailed morphology.\n\n    Plots the traced morphology it was traced. That means at every point that was\n    traced a disc of radius `r` is plotted. The outline of the discs are then\n    connected to form the morphology. This means every trace segement can be\n    represented by a cone frustum. To prevent breaks in the morphology, each\n    segement is connected with a ball joint.\n\n    Args:\n        module_or_view: The module or view to plot.\n        dims: The dimensions to plot / to project the cylinder onto,\n            i.e. [0,1] xy-plane or [0,1,2] for 3D.\n        color: The color for all branches\n        ax: The matplotlib axis to plot on.\n        kwargs: The plot kwargs for plt.fill.\n\n        resolution: defines the resolution of the mesh.\n            If too low (typically <10), can result in errors.\n            Useful too have a simpler mesh for plotting.\n\n    Returns:\n        Plot of the detailed morphology.\"\"\"\n    if ax is None:\n        fig = plt.figure(figsize=(3, 3))\n        ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n    if len(dims) == 3:\n        warn(\n            \"rendering large morphologies in 3D can take a while. Consider projecting to 2D instead.\"\n        )\n\n    assert not np.any(\n        np.isnan(module_or_view.xyzr[0][:, :3])\n    ), \"missing xyz coordinates.\"\n\n    for xyzr in module_or_view.xyzr:\n        if len(xyzr) > 1:\n            for xyzr1, xyzr2 in zip(xyzr[1:, :], xyzr[:-1, :]):\n                dxyz = xyzr2[:3] - xyzr1[:3]\n                length = np.sqrt(np.sum(dxyz**2))\n                points = create_cone_frustum_mesh(\n                    length,\n                    xyzr1[-1],\n                    xyzr2[-1],\n                    bottom_dome=True,\n                    top_dome=True,\n                    resolution=resolution,\n                )\n                plot_mesh(\n                    points,\n                    dxyz,\n                    xyzr1[:3],\n                    np.array(dims),\n                    color=color,\n                    ax=ax,\n                    **kwargs,\n                )\n        else:\n            points = create_cone_frustum_mesh(\n                0,\n                xyzr[:, -1],\n                xyzr[:, -1],\n                bottom_dome=True,\n                top_dome=True,\n                resolution=resolution,\n            )\n            plot_mesh(\n                points,\n                np.ones(3),\n                xyzr[0, :3],\n                dims=np.array(dims),\n                color=color,\n                ax=ax,\n                **kwargs,\n            )\n\n    return ax\n
"},{"location":"reference/utils/#jaxley.utils.jax_utils.nested_checkpoint_scan","title":"nested_checkpoint_scan(f, init, xs, length=None, *, nested_lengths, scan_fn=jax.lax.scan, checkpoint_fn=jax.checkpoint)","text":"

A version of lax.scan that supports recursive gradient checkpointing.

Code taken from: https://github.com/google/jax/issues/2139

The interface of nested_checkpoint_scan exactly matches lax.scan, except for the required nested_lengths argument.

The key feature of nested_checkpoint_scan is that gradient calculations require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested scans, which it achieves by re-evaluating the forward pass len(nested_lengths) - 1 times.

nested_checkpoint_scan reduces to lax.scan when nested_lengths has a single element.

Parameters:

Name Type Description Default f Callable[[Carry, Dict[str, ndarray]], Tuple[Carry, Output]]

function to scan over.

required init Carry

initial value.

required xs Dict[str, ndarray]

scanned over values.

required length Optional[int]

leading length of all dimensions

None nested_lengths Sequence[int]

required list of lengths to scan over for each level of checkpointing. The product of nested_lengths must match length (if provided) and the size of the leading axis for all arrays in xs.

required scan_fn

function matching the API of lax.scan

scan checkpoint_fn Callable[[Func], Func]

function matching the API of jax.checkpoint.

checkpoint Source code in jaxley/utils/jax_utils.py
def nested_checkpoint_scan(\n    f: Callable[[Carry, Dict[str, jnp.ndarray]], Tuple[Carry, Output]],\n    init: Carry,\n    xs: Dict[str, jnp.ndarray],\n    length: Optional[int] = None,\n    *,\n    nested_lengths: Sequence[int],\n    scan_fn=jax.lax.scan,\n    checkpoint_fn: Callable[[Func], Func] = jax.checkpoint,\n):\n    \"\"\"A version of lax.scan that supports recursive gradient checkpointing.\n\n    Code taken from: https://github.com/google/jax/issues/2139\n\n    The interface of `nested_checkpoint_scan` exactly matches lax.scan, except for\n    the required `nested_lengths` argument.\n\n    The key feature of `nested_checkpoint_scan` is that gradient calculations\n    require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested\n    scans, which it achieves by re-evaluating the forward pass\n    `len(nested_lengths) - 1` times.\n\n    `nested_checkpoint_scan` reduces to `lax.scan` when `nested_lengths` has a\n    single element.\n\n    Args:\n        f: function to scan over.\n        init: initial value.\n        xs: scanned over values.\n        length: leading length of all dimensions\n        nested_lengths: required list of lengths to scan over for each level of\n            checkpointing. The product of nested_lengths must match length (if\n            provided) and the size of the leading axis for all arrays in ``xs``.\n        scan_fn: function matching the API of lax.scan\n        checkpoint_fn: function matching the API of jax.checkpoint.\n    \"\"\"\n    if length is not None and length != math.prod(nested_lengths):\n        raise ValueError(f\"inconsistent {length=} and {nested_lengths=}\")\n\n    def nested_reshape(x):\n        x = jnp.asarray(x)\n        new_shape = tuple(nested_lengths) + x.shape[1:]\n        return x.reshape(new_shape)\n\n    sub_xs = jax.tree_util.tree_map(nested_reshape, xs)\n    return _inner_nested_scan(f, init, sub_xs, nested_lengths, scan_fn, checkpoint_fn)\n
"},{"location":"reference/utils/#jaxley.utils.syn_utils.gather_synapes","title":"gather_synapes(number_of_compartments, post_syn_comp_inds, current_each_synapse_voltage_term, current_each_synapse_constant_term)","text":"

Compute current at the post synapse.

All this does it that it sums the synaptic currents that come into a particular compartment. It returns an array of as many elements as there are compartments.

Source code in jaxley/utils/syn_utils.py
def gather_synapes(\n    number_of_compartments: jnp.ndarray,\n    post_syn_comp_inds: np.ndarray,\n    current_each_synapse_voltage_term: jnp.ndarray,\n    current_each_synapse_constant_term: jnp.ndarray,\n) -> Tuple[jnp.ndarray, jnp.ndarray]:\n    \"\"\"Compute current at the post synapse.\n\n    All this does it that it sums the synaptic currents that come into a particular\n    compartment. It returns an array of as many elements as there are compartments.\n    \"\"\"\n    incoming_currents_voltages = jnp.zeros((number_of_compartments,))\n    incoming_currents_contant = jnp.zeros((number_of_compartments,))\n\n    dnums = ScatterDimensionNumbers(\n        update_window_dims=(),\n        inserted_window_dims=(0,),\n        scatter_dims_to_operand_dims=(0,),\n    )\n    incoming_currents_voltages = scatter_add(\n        incoming_currents_voltages,\n        post_syn_comp_inds[:, None],\n        current_each_synapse_voltage_term,\n        dnums,\n    )\n    incoming_currents_contant = scatter_add(\n        incoming_currents_contant,\n        post_syn_comp_inds[:, None],\n        current_each_synapse_constant_term,\n        dnums,\n    )\n    return incoming_currents_voltages, incoming_currents_contant\n
"},{"location":"tutorial/00_jaxley_api/","title":"Key concepts in Jaxley","text":"

In this tutorial, we will introduce you to the basic concepts of Jaxley. You will learn about:

  • Modules (e.g., Cell, Network,\u2026)
    • nodes
    • edges
  • Views
    • Groups
  • Channels
  • Synapses

Here is a code snippet which you will learn to understand in this tutorial:

import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import connect\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n\n# Assembling different Modules into a Network\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=1)\ncell = jx.Cell(branch, parents=[-1, 0, 0])\nnet = jx.Network([cell]*3)\n\n# Navigating and inspecting the Modules using Views\ncell0 = net.cell(0)\ncell0.nodes\n\n# How to group together parts of Modules\nnet.cell(1).add_to_group(\"cell1\")\n\n# inserting channels in the membrane\nwith net.cell(0) as cell0:\n    cell0.insert(Na())\n    cell0.insert(K())\n\n# connecting two cells using a Synapse\npre_comp = cell0.branch(1).comp(0)\npost_comp = net.cell1.branch(0).comp(0)\n\nconnect(pre_comp, post_comp)\n

First, we import the relevant libraries:

from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import connect\nimport matplotlib.pyplot as plt\nimport numpy as np\n
"},{"location":"tutorial/00_jaxley_api/#modules","title":"Modules","text":"

In Jaxley, we heavily rely on the concept of Modules to build biophyiscal models of neural systems at various scales. Jaxley implements four types of Modules: - Compartment - Branch - Cell - Network

Modules can be connected together to build increasingly detailed and complex models. Compartment -> Branch -> Cell -> Network.

Compartments are the atoms of biophysical models in Jaxley. All mechanisms and synaptic connections live on the level of Compartments and can already be simulated using jx.integrate on their own. Everything you do in Jaxley starts with a Compartment.

comp = jx.Compartment() # single compartment model.\n

Mutliple Compartments can be connected together to form longer, linear cables, which we call Branches and are equivalent to sections in NEURON.

ncomp = 4\nbranch = jx.Branch([comp] * ncomp)\n

In order to construct cell morphologies in Jaxley, multiple Branches can to be connected together as a Cell:

# -1 indicates that the first branch has no parent branch.\n# The other two branches both have the 0-eth branch as their parent.\nparents = [-1, 0, 0]\ncell = jx.Cell([branch] * len(parents), parents)\n

Finally, several Cells can be grouped together to form a Network, which can than be connected together using Synpases.

ncells = 2\nnet = jx.Network([cell]*ncells)\n\nnet.shape # shows you the num_cells, num_branches, num_comps\n
(2, 6, 24)\n

Every module tracks information about its current state and parameters in two Dataframes called nodes and edges. nodes contains all the information that we associate with compartments in the model (each row corresponds to one compartment) and edges tracks all the information relevant to synapses.

This means that you can easily keep track of the current state of your Module and how it changes at all times.

net.nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 0 0 0 0 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 0 0 1 0 2 0 0 2 10.0 1.0 5000.0 1.0 -70.0 0 0 2 0 3 0 0 3 10.0 1.0 5000.0 1.0 -70.0 0 0 3 0 4 0 1 0 10.0 1.0 5000.0 1.0 -70.0 0 1 4 0 5 0 1 1 10.0 1.0 5000.0 1.0 -70.0 0 1 5 0 6 0 1 2 10.0 1.0 5000.0 1.0 -70.0 0 1 6 0 7 0 1 3 10.0 1.0 5000.0 1.0 -70.0 0 1 7 0 8 0 2 0 10.0 1.0 5000.0 1.0 -70.0 0 2 8 0 9 0 2 1 10.0 1.0 5000.0 1.0 -70.0 0 2 9 0 10 0 2 2 10.0 1.0 5000.0 1.0 -70.0 0 2 10 0 11 0 2 3 10.0 1.0 5000.0 1.0 -70.0 0 2 11 0 12 1 0 0 10.0 1.0 5000.0 1.0 -70.0 1 3 12 0 13 1 0 1 10.0 1.0 5000.0 1.0 -70.0 1 3 13 0 14 1 0 2 10.0 1.0 5000.0 1.0 -70.0 1 3 14 0 15 1 0 3 10.0 1.0 5000.0 1.0 -70.0 1 3 15 0 16 1 1 0 10.0 1.0 5000.0 1.0 -70.0 1 4 16 0 17 1 1 1 10.0 1.0 5000.0 1.0 -70.0 1 4 17 0 18 1 1 2 10.0 1.0 5000.0 1.0 -70.0 1 4 18 0 19 1 1 3 10.0 1.0 5000.0 1.0 -70.0 1 4 19 0 20 1 2 0 10.0 1.0 5000.0 1.0 -70.0 1 5 20 0 21 1 2 1 10.0 1.0 5000.0 1.0 -70.0 1 5 21 0 22 1 2 2 10.0 1.0 5000.0 1.0 -70.0 1 5 22 0 23 1 2 3 10.0 1.0 5000.0 1.0 -70.0 1 5 23 0
net.edges.head() # this is currently empty since we have not made any connections yet\n
global_edge_index pre_global_comp_index post_global_comp_index pre_locs post_locs type type_ind"},{"location":"tutorial/00_jaxley_api/#views","title":"Views","text":"

Since these Modules can become very complex, Jaxley utilizes so called Views to make working with Modules easy and intuitive.

The simplest way to navigate Modules is by navigating them via the hierachy that we introduced above. A View is what you get when you index into the module. For example, for a Network:

net.cell(0)\n
View with 0 different channels. Use `.nodes` for details.\n

Views behave very similarly to Modules, i.e. the cell(0) (the 0th cell of the network) behaves like the cell we instantiated earlier. As such, cell(0) also has a nodes attribute, which keeps track of it\u2019s part of the network:

net.cell(0).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 0 0 0 0 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 0 0 1 0 2 0 0 2 10.0 1.0 5000.0 1.0 -70.0 0 0 2 0 3 0 0 3 10.0 1.0 5000.0 1.0 -70.0 0 0 3 0 4 0 1 0 10.0 1.0 5000.0 1.0 -70.0 0 1 4 0 5 0 1 1 10.0 1.0 5000.0 1.0 -70.0 0 1 5 0 6 0 1 2 10.0 1.0 5000.0 1.0 -70.0 0 1 6 0 7 0 1 3 10.0 1.0 5000.0 1.0 -70.0 0 1 7 0 8 0 2 0 10.0 1.0 5000.0 1.0 -70.0 0 2 8 0 9 0 2 1 10.0 1.0 5000.0 1.0 -70.0 0 2 9 0 10 0 2 2 10.0 1.0 5000.0 1.0 -70.0 0 2 10 0 11 0 2 3 10.0 1.0 5000.0 1.0 -70.0 0 2 11 0

Let\u2019s use Views to visualize only parts of the Network. Before we do that, we create x, y, and z coordinates for the Network:

# Compute xyz coordinates of the cells.\nnet.compute_xyz()\n\n# Move cells (since they are placed on top of each other by default).\nnet.cell(0).move(y=30)\n

We can now visualize the entire net (i.e., the entire Module) with the .vis() method\u2026

# We can use the vis function to visualize Modules.\nfig, ax = plt.subplots(1, 1, figsize=(3,3))\nnet.vis(ax=ax)\n
<Axes: >\n

\u2026but we can also create a View to visualize only parts of the net:

# ... and Views\nfig, ax = plt.subplots(1,1, figsize=(3,3))\nnet.cell(0).vis(ax=ax, color=\"blue\") # View of the 0th cell of the network\nnet.cell(1).vis(ax=ax, color=\"red\") # View of the 1st cell of the network\n\nnet.cell(0).branch(0).vis(ax=ax, color=\"green\") # View of the 1st branch of the 0th cell of the network\nnet.cell(1).branch(1).comp(1).vis(ax=ax, color=\"black\", type=\"line\") # View of the 0th comp of the 1st branch of the 0th cell of the network\n
<Axes: >\n

"},{"location":"tutorial/00_jaxley_api/#how-to-create-views","title":"How to create Views","text":"

Above, we used net.cell(0) to generate a View of the 0-eth cell. Jaxley supports many ways of performing such indexing:

# several types of indices are supported (lists, ranges, ...)\nnet.cell([0,1]).branch(\"all\").comp(0)  # View of all 0th comps of all branches of cell 0 and 1\n\nbranch.loc(0.1)  # Equivalent to `NEURON`s `loc`. Assumes branches are continous from 0-1.\n\nnet[0,0,0]  # Modules/Views can also be lazily indexed\n\ncell0 = net.cell(0)  # Views can be assigned to variables and only track the parts of the Module they belong to\ncell0.branch(1).comp(0)  # Views can be continuely indexed\n
View with 0 different channels. Use `.nodes` for details.\n
cell0.nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 0 0 0 0 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 0 0 1 0 2 0 0 2 10.0 1.0 5000.0 1.0 -70.0 0 0 2 0 3 0 0 3 10.0 1.0 5000.0 1.0 -70.0 0 0 3 0 4 0 1 0 10.0 1.0 5000.0 1.0 -70.0 0 1 4 0 5 0 1 1 10.0 1.0 5000.0 1.0 -70.0 0 1 5 0 6 0 1 2 10.0 1.0 5000.0 1.0 -70.0 0 1 6 0 7 0 1 3 10.0 1.0 5000.0 1.0 -70.0 0 1 7 0 8 0 2 0 10.0 1.0 5000.0 1.0 -70.0 0 2 8 0 9 0 2 1 10.0 1.0 5000.0 1.0 -70.0 0 2 9 0 10 0 2 2 10.0 1.0 5000.0 1.0 -70.0 0 2 10 0 11 0 2 3 10.0 1.0 5000.0 1.0 -70.0 0 2 11 0
net.shape\n
(2, 6, 24)\n

Note: In case you need even more flexibility in how you select parts of a Module, Jaxley provides a select method, to give full control over the exact parts of the nodes and edges that are part of a View. On examples of how this can be used, see the tutorial on advanced indexing.

You can also iterate over networks, cells, and branches:

# We set the radiuses to random values...\nradiuses = np.random.rand((24))\nnet.set(\"radius\", radiuses)\n\n# ...and then we set the length to 100.0 um if the radius is >0.5.\nfor cell in net:\n    for branch in cell:\n        for comp in branch:\n            if comp.nodes.iloc[0][\"radius\"] > 0.5:\n                comp.set(\"length\", 100.0)\n\n# Show the first five compartments:\nnet.nodes[[\"radius\", \"length\"]][:5]\n
radius length 0 0.537066 100.0 1 0.050138 10.0 2 0.913129 100.0 3 0.874596 100.0 4 0.048903 10.0

Finally, you can also use Views in a context manager:

with net.cell(0).branch(0) as branch0:\n    branch0.set(\"radius\", 2.0)\n    branch0.set(\"length\", 2.5)\n\n# Show the first five compartments.\nnet.nodes[[\"radius\", \"length\"]][:5]\n
radius length 0 2.000000 2.5 1 2.000000 2.5 2 2.000000 2.5 3 2.000000 2.5 4 0.048903 10.0"},{"location":"tutorial/00_jaxley_api/#channels","title":"Channels","text":"

The Modules that we have created above will not do anything interesting, since by default Jaxley initializes them without any mechanisms in the membrane. To change this, we have to insert channels into the membrane. For this purpose Jaxley implements Channels that can be inserted into any compartment using the insert method of a Module or a View:

# insert a Leak channel into all compartments in the Module.\nnet.insert(Leak())\nnet.nodes.head() # Channel parameters are now also added to `nodes`.\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param Leak Leak_gLeak Leak_eLeak 0 0 0 0 2.5 2.000000 5000.0 1.0 -70.0 0 0 0 0 True 0.0001 -70.0 1 0 0 1 2.5 2.000000 5000.0 1.0 -70.0 0 0 1 0 True 0.0001 -70.0 2 0 0 2 2.5 2.000000 5000.0 1.0 -70.0 0 0 2 0 True 0.0001 -70.0 3 0 0 3 2.5 2.000000 5000.0 1.0 -70.0 0 0 3 0 True 0.0001 -70.0 4 0 1 0 10.0 0.048903 5000.0 1.0 -70.0 0 1 4 0 True 0.0001 -70.0

This is also were Views come in handy, as it allows to easily target the insertion of channels to specific compartments.

# inserting several channels into parts of the network\nwith net.cell(0) as cell0:\n    cell0.insert(Na())\n    cell0.insert(K())\n\n# # The above is equivalent to:\n# net.cell(0).insert(Na())\n# net.cell(0).insert(K())\n\n# K and Na channels were only insert into cell 0\nnet.cell(\"all\").branch(0).comp(0).nodes[[\"global_cell_index\", \"Na\", \"K\", \"Leak\"]]\n
global_cell_index Na K Leak 0 0 True True True 12 1 False False True"},{"location":"tutorial/00_jaxley_api/#synapses","title":"Synapses","text":"

To connect different cells together, Jaxley implements a connect method, that can be used to couple 2 compartments together using a Synapse. Synapses in Jaxley work only on the compartment level, that means to be able to connect two cells, you need to specify the exact compartments on a given cell to make the connections between. Below is an example of this:

# connecting two cells using a Synapse\npre_comp = cell0.branch(1).comp(0)\npost_comp = net.cell(1).branch(0).comp(0)\n\nconnect(pre_comp, post_comp, IonotropicSynapse())\n\nnet.edges\n
global_edge_index pre_global_comp_index post_global_comp_index type type_ind pre_locs post_locs IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s controlled_by_param 0 0 4 12 IonotropicSynapse 0 0.125 0.125 0.0001 0.0 0.025 0.2 0

As you can see above, now the edges dataframe is also updated with the information of the newly added synapse.

Congrats! You should now have an intuitive understand of how to use Jaxley\u2019s API to construct, navigate and manipulate neuron models.

"},{"location":"tutorial/01_morph_neurons/","title":"Basics of Jaxley","text":"

In this tutorial, you will learn how to:

  • build your first morphologically detailed cell or read it from SWC
  • stimulate the cell
  • record from the cell
  • visualize cells
  • run your first simulation

Here is a code snippet which you will learn to understand in this tutorial:

import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nimport matplotlib.pyplot as plt\n\n\n# Build the cell.\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])\n\n# Insert channels.\ncell.insert(Leak())\ncell.branch(0).insert(Na())\ncell.branch(0).insert(K())\n\n# Change parameters.\ncell.set(\"axial_resistivity\", 200.0)\n\n# Visualize the morphology.\ncell.compute_xyz()\nfig, ax = plt.subplots(1, 1, figsize=(4, 4))\ncell.vis(ax=ax)\n\n# Stimulate.\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=10.0)\ncell.branch(0).loc(0.0).stimulate(current)\n\n# Record.\ncell.branch(0).loc(0.0).record(\"v\")\n\n# Simulate and plot.\nv = jx.integrate(cell, delta_t=0.025)\nplt.plot(v.T)\n

First, we import the relevant libraries:

from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect\n

We will now build our first cell in Jaxley. You have two options to do this: you can either build a cell bottom-up by defining the morphology yourselve, or you can load cells from SWC files.

"},{"location":"tutorial/01_morph_neurons/#define-the-cell-from-scratch","title":"Define the cell from scratch","text":"

To define a cell from scratch you first have to define a single compartment and then assemble those compartments into a branch:

comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\n

Next, we can assemble branches into a cell. To do so, we have to define for each branch what its parent branch is. A -1 entry means that this branch does not have a parent.

parents = jnp.asarray([-1, 0, 0, 1, 1])\ncell = jx.Cell(branch, parents=parents)\n

To learn more about Compartments, Branches, and Cells, see this tutorial.

"},{"location":"tutorial/01_morph_neurons/#read-the-cell-from-an-swc-file","title":"Read the cell from an SWC file","text":"

Alternatively, you could also load cells from SWC with

cell = jx.read_swc(fname, ncomp=4)

Details on handling SWC files can be found in this tutorial.

"},{"location":"tutorial/01_morph_neurons/#visualize-the-cells","title":"Visualize the cells","text":"

Cells can be visualized as follows:

cell.compute_xyz()  # Only needed for visualization.\n\nfig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax, color=\"k\")\n

"},{"location":"tutorial/01_morph_neurons/#insert-mechanisms","title":"Insert mechanisms","text":"

Currently, the cell does not contain any kind of ion channel (not even a leak). We can fix this by inserting a leak channel into the entire cell, and by inserting sodium and potassium into the zero-eth branch.

cell.insert(Leak())\ncell.branch(0).insert(Na())\ncell.branch(0).insert(K())\n

Once the cell is created, we can inspect its .nodes attribute which lists all properties of the cell:

cell.nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index ... Na Na_gNa eNa vt Na_m Na_h K K_gK eK K_n 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 0 0 ... True 0.05 50.0 -60.0 0.2 0.2 True 0.005 -90.0 0.2 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 0 0 ... True 0.05 50.0 -60.0 0.2 0.2 True 0.005 -90.0 0.2 2 0 1 0 10.0 1.0 5000.0 1.0 -70.0 0 1 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 3 0 1 1 10.0 1.0 5000.0 1.0 -70.0 0 1 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 4 0 2 0 10.0 1.0 5000.0 1.0 -70.0 0 2 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 5 0 2 1 10.0 1.0 5000.0 1.0 -70.0 0 2 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 6 0 3 0 10.0 1.0 5000.0 1.0 -70.0 0 3 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 7 0 3 1 10.0 1.0 5000.0 1.0 -70.0 0 3 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 8 0 4 0 10.0 1.0 5000.0 1.0 -70.0 0 4 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 9 0 4 1 10.0 1.0 5000.0 1.0 -70.0 0 4 ... False NaN NaN NaN NaN NaN False NaN NaN NaN

10 rows \u00d7 25 columns

Note that Jaxley uses the same units as the NEURON simulator, which are listed here.

You can also inspect just parts of the cell, for example its 1st branch:

cell.branch(1).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v Leak Leak_gLeak ... Na_m Na_h K K_gK eK K_n global_cell_index global_branch_index global_comp_index controlled_by_param 2 0 0 0 10.0 1.0 5000.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 2 1 3 0 0 1 10.0 1.0 5000.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 3 1

2 rows \u00d7 25 columns

The easiest way to know which branch is the 1st branch (or, e.g., the zero-eth compartment of the 1st branch) is to plot it in a different color:

fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax, color=\"k\")\n_ = cell.branch(1).vis(ax=ax, color=\"r\")\n_ = cell.branch(1).comp(1).vis(ax=ax, color=\"b\")\n

More background and features on indexing as cell.branch(0) is in this tutorial.

"},{"location":"tutorial/01_morph_neurons/#change-parameters-of-the-cell","title":"Change parameters of the cell","text":"

You can change properties of the cell with the .set() method:

cell.branch(1).set(\"axial_resistivity\", 200.0)\n

And we can again inspect the .nodes to make sure that the axial resistivity indeed changed:

cell.branch(1).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v Leak Leak_gLeak ... Na_m Na_h K K_gK eK K_n global_cell_index global_branch_index global_comp_index controlled_by_param 2 0 0 0 10.0 1.0 200.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 2 1 3 0 0 1 10.0 1.0 200.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 3 1

2 rows \u00d7 25 columns

In a similar way, you can modify channel properties or initial states (units are again here):

cell.branch(0).set(\"K_gK\", 0.01)  # modify potassium conductance.\ncell.set(\"v\", -65.0)  # modify initial voltage.\n
"},{"location":"tutorial/01_morph_neurons/#stimulate-the-cell","title":"Stimulate the cell","text":"

We next stimulate one of the compartments with a step current. For this, we first define the step current (units are again here):

dt = 0.025\nt_max = 10.0\ntime_vec = np.arange(0, t_max+dt, dt)\ncurrent = jx.step_current(i_delay=1.0, i_dur=2.0, i_amp=0.08, delta_t=dt, t_max=t_max)\n\nfig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = plt.plot(time_vec, current)\n

We then stimulate one of the compartments of the cell with this step current:

cell.delete_stimuli()\ncell.branch(0).loc(0.0).stimulate(current)\n
Added 1 external_states. See `.externals` for details.\n
"},{"location":"tutorial/01_morph_neurons/#define-recordings","title":"Define recordings","text":"

Next, you have to define where to record the voltage. In this case, we will record the voltage at two locations:

cell.delete_recordings()\ncell.branch(0).loc(0.0).record(\"v\")\ncell.branch(3).loc(1.0).record(\"v\")\n
Added 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n

We can again visualize these locations to understand where we inserted recordings:

fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax)\n_ = cell.branch(0).loc(0.0).vis(ax=ax, color=\"b\")\n_ = cell.branch(3).loc(1.0).vis(ax=ax, color=\"g\")\n

"},{"location":"tutorial/01_morph_neurons/#simulate-the-cell-response","title":"Simulate the cell response","text":"

Having set up the cell, inserted stimuli and recordings, we are now ready to run a simulation with jx.integrate:

voltages = jx.integrate(cell, delta_t=dt)\nprint(\"voltages.shape\", voltages.shape)\n
voltages.shape (2, 402)\n

The jx.integrate function returns an array of shape (num_recordings, num_timepoints). In our case, we inserted 2 recordings and we simulated for 10ms at a 0.025 time step, which leads to 402 time steps.

We can now visualize the voltage response:

fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(voltages[0], c=\"b\")\n_ = ax.plot(voltages[1], c=\"orange\")\n

At the location of the first recording (in blue) the cell spiked, whereas at the second recording, it did not. This makes sense because we only inserted sodium and potassium channels into the first branch, but not in the entire cell.

Congrats! You have just run your first morphologically detailed neuron simulation in Jaxley. We suggest to continue by learning how to build networks. If you are only interested in single cell simulations, you can directly jump to learning how to speed up simulations. If you want to simulate detailed morphologies from SWC files, checkout our tutorial on working with detailed morphologies.

"},{"location":"tutorial/02_small_network/","title":"Network simulations in Jaxley","text":"

In this tutorial, you will learn how to:

  • connect neurons into a network
  • visualize networks
  • use the .edges attribute to inspect and change synaptic parameters

Here is a code snippet which you will learn to understand in this tutorial:

import jaxley as jx\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import connect\n\n\n# Define a network. `cell` is defined as in previous tutorial.\nnet = jx.Network([cell for _ in range(11)])\n\n# Define synapses.\nfully_connect(\n    net.cell(range(10)),\n    net.cell(10),\n    IonotropicSynapse(),\n)\n\n# Change synaptic parameters.\nnet.select(edges=[0, 1]).set(\"IonotropicSynapse_gS\", 0.1)  # nS\n\n# Visualize the network.\nnet.compute_xyz()\nfig, ax = plt.subplots(1, 1, figsize=(4, 4))\nnet.vis(ax=ax, detail=\"full\", layers=[10, 1])  # or `detail=\"point\"`.\n

In the previous tutorial, you learned how to build single cells with morphological detail, how to insert stimuli and recordings, and how to run a first simulation. In this tutorial, we will define networks of multiple cells and connect them with synapses. Let\u2019s get started:

from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect, connect\n
"},{"location":"tutorial/02_small_network/#define-the-network","title":"Define the network","text":"

First, we define a cell as you saw in the previous tutorial.

comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n

We can assemble multiple cells into a network by using jx.Network, which takes a list of jx.Cells. Here, we assemble 11 cells into a network:

num_cells = 11\nnet = jx.Network([cell for _ in range(num_cells)])\n

At this point, we can already visualize this network:

net.compute_xyz()\nnet.rotate(180)\nnet.arrange_in_layers(layers=[10, 1], within_layer_offset=150, between_layer_offset=200)\n\nfig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\")\n

Note: you can use move_to to have more control over the location of cells, e.g.: network.cell(i).move_to(x=0, y=200).

As you can see, the neurons are not connected yet. Let\u2019s fix this by connecting neurons with synapses. We will build a network consisting of two layers: 10 neurons in the input layer and 1 neuron in the output layer.

We can use Jaxley\u2019s fully_connect method to connect these layers:

pre = net.cell(range(10))\npost = net.cell(10)\nfully_connect(pre, post, IonotropicSynapse())\n

Let\u2019s visualize this again:

fig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\")\n

As you can see, the full_connect method inserted one synapse (in blue) from every neuron in the first layer to the output neuron. The fully_connect method builds this synapse from the zero-eth compartment and zero-eth branch of the presynaptic neuron onto a random branch of the postsynaptic neuron. If you want more control over the pre- and post-synaptic branches, you can use the connect method:

pre = net.cell(0).branch(5).loc(1.0)\npost = net.cell(10).branch(0).loc(0.0)\nconnect(pre, post, IonotropicSynapse())\n
fig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\")\n

"},{"location":"tutorial/02_small_network/#inspecting-and-changing-synaptic-parameters","title":"Inspecting and changing synaptic parameters","text":"

You can inspect synaptic parameters via the .edges attribute:

net.edges\n
global_edge_index pre_global_comp_index post_global_comp_index type type_ind pre_locs post_locs IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s controlled_by_param 0 0 0 287 IonotropicSynapse 0 0.125 0.875 0.0001 0.0 0.025 0.2 0 1 1 28 289 IonotropicSynapse 0 0.125 0.375 0.0001 0.0 0.025 0.2 0 2 2 56 289 IonotropicSynapse 0 0.125 0.375 0.0001 0.0 0.025 0.2 0 3 3 84 301 IonotropicSynapse 0 0.125 0.375 0.0001 0.0 0.025 0.2 0 4 4 112 281 IonotropicSynapse 0 0.125 0.375 0.0001 0.0 0.025 0.2 0 5 5 140 295 IonotropicSynapse 0 0.125 0.875 0.0001 0.0 0.025 0.2 0 6 6 168 289 IonotropicSynapse 0 0.125 0.375 0.0001 0.0 0.025 0.2 0 7 7 196 290 IonotropicSynapse 0 0.125 0.625 0.0001 0.0 0.025 0.2 0 8 8 224 303 IonotropicSynapse 0 0.125 0.875 0.0001 0.0 0.025 0.2 0 9 9 252 280 IonotropicSynapse 0 0.125 0.125 0.0001 0.0 0.025 0.2 0 10 10 23 280 IonotropicSynapse 0 0.875 0.125 0.0001 0.0 0.025 0.2 0

To modify a parameter of all synapses you can again use .set():

net.set(\"IonotropicSynapse_gS\", 0.0003)  # nS\n

To modify individual syanptic parameters, use the .select() method. Below, we change the values of the first two synapses:

net.select(edges=[0, 1]).set(\"IonotropicSynapse_gS\", 0.0004)  # nS\n

For more details on how to flexibly set synaptic parameters (e.g., by cell type, or by pre-synaptic cell index,\u2026), see this tutorial.

"},{"location":"tutorial/02_small_network/#stimulating-recording-and-simulating-the-network","title":"Stimulating, recording, and simulating the network","text":"

We will now set up a simulation of the network. This works exactly as it does for single neurons:

# Stimulus.\ni_delay = 3.0  # ms\ni_amp = 0.05  # nA\ni_dur = 2.0  # ms\n\n# Duration and step size.\ndt = 0.025  # ms\nt_max = 50.0  # ms\n
time_vec = jnp.arange(0.0, t_max + dt, dt)\n

As a simple example, we insert sodium, potassium, and leak into every compartment of every cell of the network.

net.insert(Na())\nnet.insert(K())\nnet.insert(Leak())\n

We stimulate every neuron in the input layer and record the voltage from the output neuron:

current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)\nnet.delete_stimuli()\nfor stim_ind in range(10):\n    net.cell(stim_ind).branch(0).loc(0.0).stimulate(current)\n\nnet.delete_recordings()\nnet.cell(10).branch(0).loc(0.0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n

Finally, we can again run the network simulation and plot the result:

s = jx.integrate(net, delta_t=dt)\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T)\n

That\u2019s it! You now know how to simulate networks of morphologically detailed neurons. We recommend that you now have a look at how you can speed up your simulation. To learn more about handling synaptic parameters, we recommend to check out this tutorial.

"},{"location":"tutorial/04_jit_and_vmap/","title":"Speeding up simulations","text":"

In this tutorial, you will learn how to:

  • make parameter sweeps in Jaxley
  • use jit to compile your simulations and make them faster
  • use vmap to parallelize simulations on GPUs

Here is a code snippet which you will learn to understand in this tutorial:

from jax import jit, vmap\n\n\ncell = ...  # See tutorial on Basics of Jaxley.\n\ndef simulate(params):\n    param_state = None\n    param_state = cell.data_set(\"Na_gNa\", params[0], param_state)\n    param_state = cell.data_set(\"K_gK\", params[1], param_state)\n    return jx.integrate(cell, param_state=param_state, delta_t=0.025)\n\n# Define 100 sets of sodium and potassium conductances.\nall_params = jnp.asarray(np.random.rand(100, 2))\n\n# Fast for-loops with jit compilation.\njitted_simulate = jit(simulate)\nvoltages = [jitted_simulate(params) for params in all_params]\n\n# Using vmap for parallelization.\nvmapped_simulate = vmap(jitted_simulate, in_axes=(0,))\nvoltages = vmapped_simulate(all_params)\n

In the previous tutorials, you learned how to build single cells or networks and how to change their parameters. In this tutorial, you will learn how to speed up such simulations by many orders of magnitude. This can be achieved in to ways:

  • by using JIT compilation
  • by using GPU parallelization

Let\u2019s get started!

"},{"location":"tutorial/04_jit_and_vmap/#using-gpu-or-cpu","title":"Using GPU or CPU","text":"

In Jaxley you can set whether you want to use gpu or cpu with the following lines at the beginning of your script:

from jax import config\nconfig.update(\"jax_platform_name\", \"cpu\")\n

JAX (and Jaxley) also allow to choose between float32 and float64. Especially on GPUs, float32 will be faster, but we have experienced stability issues when simulating morphologically detailed neurons with float32.

config.update(\"jax_enable_x64\", True)  # Set to false to use `float32`.\n

Next, we will import relevant libraries:

import matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit, vmap\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\n
"},{"location":"tutorial/04_jit_and_vmap/#building-the-cell-or-network","title":"Building the cell or network","text":"

We first build a cell (or network) in the same way as we showed in the previous tutorials:

dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n\ncell.insert(Na())\ncell.insert(K())\ncell.insert(Leak())\n\ncell.delete_stimuli()\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=dt, t_max=t_max)\ncell.branch(0).loc(0.0).stimulate(current)\n\ncell.delete_recordings()\ncell.branch(0).loc(0.0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n
"},{"location":"tutorial/04_jit_and_vmap/#parameter-sweeps","title":"Parameter sweeps","text":"

Assume you want to run the same cell with many different values for the sodium and potassium conductance, for example for genetic algorithms or for parameter sweeps. To do this efficiently in Jaxley, you have to use the data_set() method (in combination with jit and vmap, as shown later):

def simulate(params):\n    param_state = None\n    param_state = cell.data_set(\"Na_gNa\", params[0], param_state)\n    param_state = cell.data_set(\"K_gK\", params[1], param_state)\n    return jx.integrate(cell, param_state=param_state, delta_t=dt)\n

The .data_set() method takes three arguments:

1) the name of the parameter you want to set. Jaxley allows to set the following parameters: \u201cradius\u201d, \u201clength\u201d, \u201caxial_resistivity\u201d, as well as all parameters of channels and synapses. 2) the value of the parameter. 3) a param_state which is initialized as None and is modified by .data_set(). This has to be passed to jx.integrate().

Having done this, the simplest (but least efficient) way to perform the parameter sweep is to run a for-loop over many parameter sets:

# Define 5 sets of sodium and potassium conductances.\nall_params = jnp.asarray(np.random.rand(5, 2))\n\nvoltages = jnp.asarray([simulate(params) for params in all_params])\nprint(\"voltages.shape\", voltages.shape)\n
voltages.shape (5, 1, 402)\n

The resulting voltages have shape (num_simulations, num_recordings, num_timesteps).

"},{"location":"tutorial/04_jit_and_vmap/#stimulus-sweeps","title":"Stimulus sweeps","text":"

In addition to running sweeps across multiple parameters, you can also run sweeeps across multiple stimuli (e.g. step current stimuli of different amplitudes. You can achieve this with the data_stimulate() method:

def simulate(i_amp):\n    current = jx.step_current(1.0, 1.0, i_amp, 0.025, 10.0)\n\n    data_stimuli = None\n    data_stimuli = cell.branch(0).comp(0).data_stimulate(current, data_stimuli)\n    return jx.integrate(cell, data_stimuli=data_stimuli)\n

"},{"location":"tutorial/04_jit_and_vmap/#speeding-up-for-loops-via-jit-compilation","title":"Speeding up for loops via jit compilation","text":"

We can speed up such parameter sweeps (or stimulus sweeps) with jit compilation. jit compilation will compile the simulation when it is run for the first time, such that every other simulation will be must faster. This can be achieved by defining a new function which uses JAX\u2019s jit():

jitted_simulate = jit(simulate)\n
# First run, will be slow.\nvoltages = jitted_simulate(all_params[0])\n
# More runs, will be much faster.\nvoltages = jnp.asarray([jitted_simulate(params) for params in all_params])\nprint(\"voltages.shape\", voltages.shape)\n
voltages.shape (5, 1, 402)\n

jit compilation can be up to 10k times faster, especially for small simulations with few compartments. For very large models, the gain obtained with jit will be much smaller (jit may even provide no speed up at all).

"},{"location":"tutorial/04_jit_and_vmap/#speeding-up-with-gpu-parallelization-via-vmap","title":"Speeding up with GPU parallelization via vmap","text":"

Another way to speed up parameter sweeps is with GPU parallelization. Parallelization in Jaxley can be achieved by using vmap of JAX. To do this, we first create a new function that handles multiple parameter sets directly:

# Using vmap for parallelization.\nvmapped_simulate = vmap(jitted_simulate)\n

We can then run this method on all parameter sets (all_params.shape == (100, 2)), and Jaxley will automatically parallelize across them. Of course, you will only get a speed-up if you have a GPU available and you specified gpu as device in the beginning of this tutorial.

voltages = vmapped_simulate(all_params)\n

GPU parallelization with vmap can give a large speed-up, which can easily be 2-3 orders of magnitude.

"},{"location":"tutorial/04_jit_and_vmap/#combining-jit-and-vmap","title":"Combining jit and vmap","text":"

Finally, you can also combine using jit and vmap. For example, you can run multiple batches of many parallel simulations. Each batch can be parallelized with vmap and simulating each batch can be compiled with jit:

jitted_vmapped_simulate = jit(vmap(simulate))\n
for batch in range(10):\n    all_params = jnp.asarray(np.random.rand(5, 2))\n    voltages_batch = jitted_vmapped_simulate(all_params)\n

That\u2019s all you have to know about jit and vmap! If you have worked through this and the previous tutorials, you should be ready to set up your first network simulations.

"},{"location":"tutorial/04_jit_and_vmap/#next-steps","title":"Next steps","text":"

If you want to learn more, we recommend you to read the tutorial on building channel and synapse models.

Alternatively, you can also directly jump ahead to the tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.

Finally, if you want to learn more about JAX, check out their tutorial on jit or their tutorial on vmap.

"},{"location":"tutorial/05_channel_and_synapse_models/","title":"Building ion channel models","text":"

In this tutorial, you will learn how to:

  • define your own ion channel models beyond the preconfigured channels in Jaxley

This tutorial assumes that you have already learned how to build basic simulations.

from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, value_and_grad\n\nimport jaxley as jx\n

First, we define a cell as you saw in the previous tutorial:

comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n

You have also already learned how to insert preconfigured channels into Jaxley models:

cell.insert(Na())\ncell.insert(K())\ncell.insert(Leak())\n

In this tutorial, we will show you how to build your own channel and synapse models.

"},{"location":"tutorial/05_channel_and_synapse_models/#your-own-channel","title":"Your own channel","text":"

Below is how you can define your own channel. We will go into detail about individual parts of the code in the next couple of cells.

import jax.numpy as jnp\nfrom jaxley.channels import Channel\nfrom jaxley.solver_gate import solve_gate_exponential\n\n\ndef exp_update_alpha(x, y):\n    return x / (jnp.exp(x / y) - 1.0)\n\nclass Potassium(Channel):\n    \"\"\"Potassium channel.\"\"\"\n\n    def __init__(self, name = None):\n        self.current_is_in_mA_per_cm2 = True\n        super().__init__(name)\n        self.channel_params = {\"gK_new\": 1e-4}\n        self.channel_states = {\"n_new\": 0.0}\n        self.current_name = \"i_K\"\n\n    def update_states(self, states, dt, v, params):\n        \"\"\"Update state.\"\"\"\n        ns = states[\"n_new\"]\n        alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\n        beta = 0.125 * jnp.exp(-(v + 65) / 80)\n        new_n = solve_gate_exponential(ns, dt, alpha, beta)\n        return {\"n_new\": new_n}\n\n    def compute_current(self, states, v, params):\n        \"\"\"Return current.\"\"\"\n        ns = states[\"n_new\"]\n        kd_conds = params[\"gK_new\"] * ns**4  # S/cm^2\n\n        e_kd = -77.0        \n        return kd_conds * (v - e_kd)\n\n    def init_state(self, states, v, params, delta_t):\n        alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\n        beta = 0.125 * jnp.exp(-(v + 65) / 80)\n        return {\"n_new\": alpha / (alpha + beta)}\n

Let\u2019s look at each part of this in detail.

The below is simply a helper function for the solver of the gate variables:

def exp_update_alpha(x, y):\n    return x / (jnp.exp(x / y) - 1.0)\n

Next, we define our channel as a class. It should inherit from the Channel class and define channel_params, channel_states, and current_name. You also need to set self.current_is_in_mA_per_cm2=True as the first line on your __init__() method. This is to acknowledge that your current is returned in mA/cm2 (not in uA/cm2, as would have been required in Jaxley versions 0.4.0 or older).

class Potassium(Channel):\n    \"\"\"Potassium channel.\"\"\"\n\n    def __init__(self, name=None):\n        self.current_is_in_mA_per_cm2 = True\n        super().__init__(name)\n        self.channel_params = {\"gK_new\": 1e-4}\n        self.channel_states = {\"n_new\": 0.0}\n        self.current_name = \"i_K\"\n

Next, we have the update_states() method, which updates the gating variables:

    def update_states(self, states, dt, v, params):\n

Every channel you define must have an update_states() method which takes exactly these five arguments (self, states, dt, v, params). The inputs states to the update_states method is a dictionary which contains all states that are updated (including states of other channels). v is a jnp.ndarray which contains the voltage of a single compartment (shape ()). Let\u2019s get the state of the potassium channel which we are building here:

ns = states[\"n_new\"]\n

Next, we update the state of the channel. In this example, we do this with exponential Euler, but you can implement any solver yourself:

alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\nbeta = 0.125 * jnp.exp(-(v + 65) / 80)\nnew_n = solve_gate_exponential(ns, dt, alpha, beta)\nreturn {\"n_new\": new_n}\n

A channel also needs a compute_current() method which returns the current throught the channel:

    def compute_current(self, states, v, params):\n        ns = states[\"n_new\"]\n\n        # Multiply with 1000 to convert Siemens to milli Siemens.\n        kd_conds = params[\"gK_new\"] * ns**4  # S/cm^2\n\n        e_kd = -77.0        \n        current = kd_conds * (v - e_kd)\n        return current\n

Finally, the init_state() method can be implemented optionally. It can be used to automatically compute the initial state based on the voltage when cell.init_states() is run.

Alright, done! We can now insert this channel into any jx.Module such as our cell:

cell.insert(Potassium())\n
cell.delete_stimuli()\ncurrent = jx.step_current(1.0, 1.0, 0.1, 0.025, 10.0)\ncell.branch(0).comp(0).stimulate(current)\n\ncell.delete_recordings()\ncell.branch(0).comp(0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n
s = jx.integrate(cell)\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T[:-1])\n_ = ax.set_ylim([-80, 50])\n_ = ax.set_xlabel(\"Time (ms)\")\n_ = ax.set_ylabel(\"Voltage (mV)\")\n

"},{"location":"tutorial/05_channel_and_synapse_models/#your-own-synapse","title":"Your own synapse","text":"

The parts below assume that you have already learned how to build network simulations in Jaxley.

Note that again, a synapse needs to have the two functions update_states and compute_current with all input arguments shown below.

The below is an example of how to define your own synapse model in Jaxley:

import jax.numpy as jnp\nfrom jaxley.synapses.synapse import Synapse\n\n\nclass TestSynapse(Synapse):\n    \"\"\"\n    Compute syanptic current and update syanpse state.\n    \"\"\"\n    def __init__(self, name = None):\n        super().__init__(name)\n        self.synapse_params = {\"gChol\": 0.001, \"eChol\": 0.0}\n        self.synapse_states = {\"s_chol\": 0.1}\n\n    def update_states(self, states, delta_t, pre_voltage, post_voltage, params):\n        \"\"\"Return updated synapse state and current.\"\"\"\n        s_inf = 1.0 / (1.0 + jnp.exp((-35.0 - pre_voltage) / 10.0))\n        exp_term = jnp.exp(-delta_t)\n        new_s = states[\"s_chol\"] * exp_term + s_inf * (1.0 - exp_term)\n        return {\"s_chol\": new_s}\n\n    def compute_current(self, states, pre_voltage, post_voltage, params):\n        g_syn = params[\"gChol\"] * states[\"s_chol\"]\n        return g_syn * (post_voltage - params[\"eChol\"])\n

As you can see above, synapses follow closely how channels are defined. The main difference is that the compute_current method takes two voltages: the pre-synaptic voltage (a jnp.ndarray of shape ()) and the post-synaptic voltage (a jnp.ndarray of shape ()).

net = jx.Network([cell for _ in range(3)])\n
from jaxley.connect import connect\n\npre = net.cell(0).branch(0).loc(0.0)\npost = net.cell(1).branch(0).loc(0.0)\nconnect(pre, post, TestSynapse())\n
net.cell(0).branch(0).loc(0.0).stimulate(jx.step_current(1.0, 2.0, 0.1, 0.025, 10.0))\nfor i in range(3):\n    net.cell(i).branch(0).loc(0.0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n
s = jx.integrate(net)\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T[:-1])\n_ = ax.set_ylim([-80, 50])\n_ = ax.set_xlabel(\"Time (ms)\")\n_ = ax.set_ylabel(\"Voltage (mV)\")\n

That\u2019s it! You are now ready to build your own custom simulations and equip them with channel and synapse models!

This tutorial does not have an immediate follow-up tutorial. If you have not done so already, you can check out our tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.

"},{"location":"tutorial/06_groups/","title":"Defining groups","text":"

In this tutorial, you will learn how to:

  • define groups (aka sectionlists) to simplify iteractions with Jaxley

Here is a code snippet which you will learn to understand in this tutorial:

from jax import jit, vmap\n\n\nnet = ...  # See tutorial on Basics of Jaxley.\n\nnet.cell(0).add_to_group(\"fast_spiking\")\nnet.cell(1).add_to_group(\"slow_spiking\")\n\ndef simulate(params):\n    param_state = None\n    param_state = net.fast_spiking.data_set(\"HH_gNa\", params[0], param_state)\n    param_state = net.slow_spiking.data_set(\"HH_gNa\", params[1], param_state)\n    return jx.integrate(net, param_state=param_state)\n\n# Define sodium for fast and slow spiking neurons.\nparams = jnp.asarray([1.0, 0.1])\n\n# Run simulation.\nvoltages = simulate(params)\n

In many cases, you might want to group several compartments (or branches, or cells) and assign a unique parameter or mechanism to this group. For example, you might want to define a couple of branches as basal and then assign a Hodgkin-Huxley mechanism only to those branches. Or you might define a couple of cells as fast spiking and assign them a high value for the sodium conductance. We describe how you can do this in this tutorial.

from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport time\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, value_and_grad\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect\n

First, we define a network as you saw in the previous tutorial:

comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1])\nnetwork = jx.Network([cell for _ in range(3)])\n\npre = network.cell([0, 1])\npost = network.cell([2])\nfully_connect(pre, post, IonotropicSynapse())\n\nnetwork.insert(Na())\nnetwork.insert(K())\nnetwork.insert(Leak())\n
"},{"location":"tutorial/06_groups/#group-apical-dendrites","title":"Group: apical dendrites","text":"

Assume that, in each of the five neurons in this network, the second and forth branch are apical dendrites. We can define this as:

for cell_ind in range(3):\n    network.cell(cell_ind).branch(1).add_to_group(\"apical\")\n    network.cell(cell_ind).branch(3).add_to_group(\"apical\")\n

After this, we can access network.apical as we previously accesses anything else:

network.apical.set(\"radius\", 0.3)\n
network.apical.view\n
View with 3 different channels. Use `.nodes` for details.\n
"},{"location":"tutorial/06_groups/#group-fast-spiking","title":"Group: fast spiking","text":"

Similarly, you could define a group of fast-spiking cells. Assume that the first and second cell are fast-spiking:

network.cell(0).add_to_group(\"fast_spiking\")\nnetwork.cell(1).add_to_group(\"fast_spiking\")\n
network.fast_spiking.set(\"Na_gNa\", 0.4)\n
network.fast_spiking.view\n
View with 3 different channels. Use `.nodes` for details.\n
"},{"location":"tutorial/06_groups/#groups-from-swc-files","title":"Groups from SWC files","text":"

If you are reading .swc morphologigies, you can automatically assign groups with

jx.read_swc(file_name, nseg=n, assign_groups=True).\n
After that, you can directly use cell.soma, cell.apical, cell.basal, or cell.axon.

"},{"location":"tutorial/06_groups/#how-groups-are-interpreted-by-make_trainable","title":"How groups are interpreted by .make_trainable()","text":"

If you make a parameter of a group trainable, then it will be treated as a single shared parameter for a given property:

network.fast_spiking.make_trainable(\"Na_gNa\")\n
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n

As such, get_parameters() returns only a single trainable parameter, which will be the sodium conductance for every compartment of every fast-spiking neuron:

network.get_parameters()\n
[{'Na_gNa': Array([0.4], dtype=float64)}]\n

If, instead, you would want a separate parameter for every fast-spiking cell, you should not use the group, but instead do the following (remember that fast-spiking neurons had indices [0,1]):

network.cell([0,1]).make_trainable(\"axial_resistivity\")\n
Number of newly added trainable parameters: 2. Total number of trainable parameters: 3\n
network.get_parameters()\n
[{'Na_gNa': Array([0.4], dtype=float64)},\n {'axial_resistivity': Array([5000., 5000.], dtype=float64)}]\n

This generated two parameters for the axial resistivitiy, each corresponding to one cell.

"},{"location":"tutorial/06_groups/#summary","title":"Summary","text":"

Groups allow you to organize your simulation in a more intuitive way, and they allow to perform parameter sharing with make_trainable().

"},{"location":"tutorial/07_gradient_descent/","title":"Training biophysical models","text":"

In this tutorial, you will learn how to train biophysical models in Jaxley. This includes the following:

  • compute the gradient with respect to parameters
  • use parameter transformations
  • use multi-level checkpointing
  • define optimizers
  • write dataloaders and parallelize across data

Here is a code snippet which you will learn to understand in this tutorial:

from jax import jit, vmap, value_and_grad\nimport jaxley as jx\nimport jaxley.optimize.transforms as jt\n\nnet = ...  # See tutorial on the basics of `Jaxley`.\n\n# Define which parameters to train.\nnet.cell(\"all\").make_trainable(\"HH_gNa\")\nnet.IonotropicSynapse.make_trainable(\"IonotropicSynapse_gS\")\nparameters = net.get_parameters()\n\n# Define parameter transform and apply it to the parameters.\ntransform = jx.ParamTransform([\n    {\"IonotropicSynapse_gS\": jt.SigmoidTransform(0.0, 1.0)},\n    {\"HH_gNa\":jt.SigmoidTransform(0.0, 1, 0)}\n])\n\nopt_params = transform.inverse(parameters)\n\n# Define simulation and batch it across stimuli.\ndef simulate(params, datapoint):\n    current = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amps=datapoint, dt=0.025, t_max=5.0)\n    data_stimuli = net.cell(0).branch(0).comp(0).data_stimulate(current, None)\n    return jx.integrate(net, params=params, data_stimuli=data_stimuli, checkpoint_inds=[20, 20], delta_t=0.025)\n\nbatch_simulate = vmap(simulate, in_axes=(None, 0))\n\n# Define loss function and its gradient.\ndef loss_fn(opt_params, datapoints, label):\n    params = transform.forward(opt_params)\n    voltages = batch_simulate(params, datapoints)\n    return jnp.abs(jnp.mean(voltages) - label)\n\ngrad_fn = jit(value_and_grad(loss_fn, argnums=0))\n\n# Define data and dataloader.\ndata = jnp.asarray(np.random.randn(100, 3))\ndataloader = Dataset.from_tensor_slices((inputs, labels))\ndataloader = dataloader.shuffle(dataloader.cardinality()).batch(4)\n\n# Define the optimizer.\noptimizer = optax.Adam(lr=0.01)\nopt_state = optimizer.init_state(opt_params)\n\nfor epoch in range(10):\n    for batch in dataloader:\n        stimuli = batch[0].numpy()\n        labels = batch[1].numpy()\n        loss, gradient = grad_fn(opt_params, stimuli, labels)\n\n        # Optimizer step.\n        updates, opt_state = optimizer.update(gradient, opt_state)\n        opt_params = optax.apply_updates(opt_params, updates)\n

from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, vmap, value_and_grad\n\nimport jaxley as jx\nfrom jaxley.channels import Leak\nfrom jaxley.synapses import TanhRateSynapse\nfrom jaxley.connect import fully_connect\n

First, we define a network as you saw in the previous tutorial:

_ = np.random.seed(0)  # For synaptic locations.\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0])\nnet = jx.Network([cell for _ in range(3)])\n\npre = net.cell([0, 1])\npost = net.cell([2])\nfully_connect(pre, post, TanhRateSynapse())\n\n# Change some default values of the tanh synapse.\nnet.TanhRateSynapse.set(\"TanhRateSynapse_x_offset\", -60.0)\nnet.TanhRateSynapse.set(\"TanhRateSynapse_gS\", 1e-3)\nnet.TanhRateSynapse.set(\"TanhRateSynapse_slope\", 0.1)\n\nnet.insert(Leak())\n

This network consists of three neurons arranged in two layers:

net.compute_xyz()\nnet.rotate(180)\nnet.arrange_in_layers(layers=[2, 1], within_layer_offset=100.0, between_layer_offset=100.0)\nfig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = net.vis(ax=ax, detail=\"full\")\n

We consider the last neuron as the output neuron and record the voltage from there:

net.delete_recordings()\nnet.cell(0).branch(0).loc(0.0).record()\nnet.cell(1).branch(0).loc(0.0).record()\nnet.cell(2).branch(0).loc(0.0).record()\n
Added 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n
"},{"location":"tutorial/07_gradient_descent/#defining-a-dataset","title":"Defining a dataset","text":"

We will train this biophysical network on a classification task. The inputs will be values and the label is binary:

inputs = jnp.asarray(np.random.rand(100, 2))\nlabels = jnp.asarray((inputs[:, 0] + inputs[:, 1]) > 1.0)\n
fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = ax.scatter(inputs[labels, 0], inputs[labels, 1])\n_ = ax.scatter(inputs[~labels, 0], inputs[~labels, 1])\n

labels = labels.astype(float)\n
"},{"location":"tutorial/07_gradient_descent/#defining-trainable-parameters","title":"Defining trainable parameters","text":"
net.delete_trainables()\n

This follows the same API as .set() seen in the previous tutorial. If you want to use a single parameter for all radiuses in the entire network, do:

net.make_trainable(\"radius\")\n
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n

We can also define parameters for individual compartments. To do this, use the \"all\" key. The following defines a separate parameter the sodium conductance for every compartment in the entire network:

net.cell(\"all\").branch(\"all\").loc(\"all\").make_trainable(\"Leak_gLeak\")\n
Number of newly added trainable parameters: 18. Total number of trainable parameters: 19\n
"},{"location":"tutorial/07_gradient_descent/#making-synaptic-parameters-trainable","title":"Making synaptic parameters trainable","text":"

Synaptic parameters can be made trainable in the exact same way. To use a single parameter for all syanptic conductances in the entire network, do

net.TanhRateSynapse.make_trainable(\"TanhRateSynapse_gS\")\n

Here, we use a different syanptic conductance for all syanpses. This can be done as follows:

net.TanhRateSynapse.edge(\"all\").make_trainable(\"TanhRateSynapse_gS\")\n
Number of newly added trainable parameters: 2. Total number of trainable parameters: 21\n
"},{"location":"tutorial/07_gradient_descent/#running-the-simulation","title":"Running the simulation","text":"

Once all parameters are defined, you have to use .get_parameters() to obtain all trainable parameters. This is also the time to check how many trainable parameters your network has:

params = net.get_parameters()\n

You can now run the simulation with the trainable parameters by passing them to the jx.integrate function.

s = jx.integrate(net, params=params, t_max=10.0)\n
"},{"location":"tutorial/07_gradient_descent/#stimulating-the-network","title":"Stimulating the network","text":"

The network above does not yet get any stimuli. We will use the 2D inputs from the dataset to stimulate the two input neurons. The amplitude of the step current corresponds to the input value. Below is the simulator that defines this:

def simulate(params, inputs):\n    currents = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amp=inputs / 10, delta_t=0.025, t_max=10.0)\n\n    data_stimuli = None\n    data_stimuli = net.cell(0).branch(2).loc(1.0).data_stimulate(currents[0], data_stimuli=data_stimuli)\n    data_stimuli = net.cell(1).branch(2).loc(1.0).data_stimulate(currents[1], data_stimuli=data_stimuli)\n\n    return jx.integrate(net, params=params, data_stimuli=data_stimuli, delta_t=0.025)\n\nbatched_simulate = vmap(simulate, in_axes=(None, 0))\n

We can also inspect some traces:

traces = batched_simulate(params, inputs[:4])\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(traces[:, 2, :].T)\n

"},{"location":"tutorial/07_gradient_descent/#defining-a-loss-function","title":"Defining a loss function","text":"

Let us define a loss function to be optimized:

def loss(params, inputs, labels):\n    traces = batched_simulate(params, inputs)  # Shape `(batchsize, num_recordings, timepoints)`.\n    prediction = jnp.mean(traces[:, 2], axis=1)  # Use the average over time of the output neuron (2) as prediction.\n    prediction = (prediction + 72.0) / 5  # Such that the prediction is roughly in [0, 1].\n    losses = jnp.abs(prediction - labels)  # Mean absolute error loss.\n    return jnp.mean(losses)  # Average across the batch.\n

And we can use JAX\u2019s inbuilt functions to take the gradient through the entire ODE:

jitted_grad = jit(value_and_grad(loss, argnums=0))\n
value, gradient = jitted_grad(params, inputs[:4], labels[:4])\n
"},{"location":"tutorial/07_gradient_descent/#defining-parameter-transformations","title":"Defining parameter transformations","text":"

Before training, however, we will enforce for all parameters to be within a prespecified range (such that, e.g., conductances can not become negative)

import jaxley.optimize.transforms as jt\n
# Define a function to create appropriate transforms for each parameter\ndef create_transform(name):\n    if name == \"axial_resistivity\":\n        # Must be positive; apply Softplus and scale to match initialization\n        return jt.ChainTransform([jt.SoftplusTransform(0), jt.AffineTransform(5000, 0)])\n    elif name == \"length\":\n        # Apply Softplus and affine transform for the 'length' parameter\n        return jt.ChainTransform([jt.SoftplusTransform(0), jt.AffineTransform(10, 0)])\n    else:\n        # Default to a Softplus transform for other parameters\n        return jt.SoftplusTransform(0)\n\n# Apply the transforms to the parameters\ntransforms = [{k: create_transform(k) for k in param} for param in params]\ntf = jt.ParamTransform(transforms)\n
transform = jx.ParamTransform([{\"radius\": jt.SigmoidTransform(0.1, 5.0)},\n                               {\"Leak_gLeak\":jt.SigmoidTransform(1e-5, 1e-3)},\n                               {\"TanhRateSynapse_gS\" : jt.SigmoidTransform(1e-5, 1e-2)}])\n

With these modify the loss function acocrdingly:

def loss(opt_params, inputs, labels):\n    transform.forward(opt_params)\n\n    traces = batched_simulate(params, inputs)  # Shape `(batchsize, num_recordings, timepoints)`.\n    prediction = jnp.mean(traces[:, 2], axis=1)  # Use the average over time of the output neuron (2) as prediction.\n    prediction = (prediction + 72.0)  # Such that the prediction is around 0.\n    losses = jnp.abs(prediction - labels)  # Mean absolute error loss.\n    return jnp.mean(losses)  # Average across the batch.\n
"},{"location":"tutorial/07_gradient_descent/#using-checkpointing","title":"Using checkpointing","text":"

Checkpointing allows to vastly reduce the memory requirements of training biophysical models (see also JAX\u2019s full tutorial on checkpointing).

t_max = 5.0\ndt = 0.025\n\nlevels = 2\ntime_points = t_max // dt + 2\ncheckpoints = [int(np.ceil(time_points**(1/levels))) for _ in range(levels)]\n

To enable checkpointing, we have to modify the simulate function appropriately and use

jx.integrate(..., checkpoint_inds=checkpoints)\n
as done below:

def simulate(params, inputs):\n    currents = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amp=inputs / 10.0, delta_t=dt, t_max=t_max)\n\n    data_stimuli = None\n    data_stimuli = net.cell(0).branch(2).loc(1.0).data_stimulate(currents[0], data_stimuli=data_stimuli)\n    data_stimuli = net.cell(1).branch(2).loc(1.0).data_stimulate(currents[1], data_stimuli=data_stimuli)\n\n    return jx.integrate(net, params=params, data_stimuli=data_stimuli, checkpoint_lengths=checkpoints)\n\nbatched_simulate = vmap(simulate, in_axes=(None, 0))\n\n\ndef predict(params, inputs):\n    traces = simulate(params, inputs)  # Shape `(batchsize, num_recordings, timepoints)`.\n    prediction = jnp.mean(traces[2])  # Use the average over time of the output neuron (2) as prediction.\n    return prediction + 72.0  # Such that the prediction is around 0.\n\nbatched_predict = vmap(predict, in_axes=(None, 0))\n\n\ndef loss(opt_params, inputs, labels):\n    params = transform.forward(opt_params)\n\n    predictions = batched_predict(params, inputs)\n    losses = jnp.abs(predictions - labels)  # Mean absolute error loss.\n    return jnp.mean(losses)  # Average across the batch.\n\njitted_grad = jit(value_and_grad(loss, argnums=0))\n
"},{"location":"tutorial/07_gradient_descent/#training","title":"Training","text":"

We will use the ADAM optimizer from the optax library to optimize the free parameters (you have to install the package with pip install optax first):

import optax\n
opt_params = transform.inverse(params)\noptimizer = optax.adam(learning_rate=0.01)\nopt_state = optimizer.init(opt_params)\n
"},{"location":"tutorial/07_gradient_descent/#writing-a-dataloader","title":"Writing a dataloader","text":"

Below, we just write our own (very simple) dataloader. Alternatively, you could use the dataloader from any deep learning library such as pytorch or tensorflow:

class Dataset:\n    def __init__(self, inputs: np.ndarray, labels: np.ndarray):\n        \"\"\"Simple Dataloader.\n\n        Args:\n            inputs: Array of shape (num_samples, num_dim)\n            labels: Array of shape (num_samples,)\n        \"\"\"\n        assert len(inputs) == len(labels), \"Inputs and labels must have same length\"\n        self.inputs = inputs\n        self.labels = labels\n        self.num_samples = len(inputs)\n        self._rng_state = None\n        self.batch_size = 1\n\n    def shuffle(self, seed=None):\n        \"\"\"Shuffle the dataset in-place\"\"\"\n        self._rng_state = np.random.get_state()[1][0] if seed is None else seed\n        np.random.seed(self._rng_state)\n        indices = np.random.permutation(self.num_samples)\n        self.inputs = self.inputs[indices]\n        self.labels = self.labels[indices]\n        return self\n\n    def batch(self, batch_size):\n        \"\"\"Create batches of the data\"\"\"\n        self.batch_size = batch_size\n        return self\n\n    def __iter__(self):\n        self.shuffle(seed=self._rng_state)\n        for start in range(0, self.num_samples, self.batch_size):\n            end = min(start + self.batch_size, self.num_samples)\n            yield self.inputs[start:end], self.labels[start:end]\n        self._rng_state += 1\n
"},{"location":"tutorial/07_gradient_descent/#training-loop","title":"Training loop","text":"
batch_size = 4\ndataloader = Dataset(inputs, labels)\ndataloader = dataloader.shuffle(seed=0).batch(batch_size)\n\nfor epoch in range(10):\n    epoch_loss = 0.0\n\n    for batch_ind, batch in enumerate(dataloader):\n        current_batch, label_batch = batch\n        loss_val, gradient = jitted_grad(opt_params, current_batch, label_batch)\n        updates, opt_state = optimizer.update(gradient, opt_state)\n        opt_params = optax.apply_updates(opt_params, updates)\n        epoch_loss += loss_val\n\n    print(f\"epoch {epoch}, loss {epoch_loss}\")\n\nfinal_params = transform.forward(opt_params)\n
epoch 0, loss 25.033223182772293\nepoch 1, loss 21.00894915349165\nepoch 2, loss 15.092242959956026\nepoch 3, loss 9.061544660383163\nepoch 4, loss 6.925509860325612\nepoch 5, loss 6.273630037897756\nepoch 6, loss 6.1757316054693145\nepoch 7, loss 6.135132525725265\nepoch 8, loss 6.145608619185389\nepoch 9, loss 6.135660902068834\n
ntest = 32\npredictions = batched_predict(final_params, inputs[:ntest])\n
fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = ax.scatter(labels[:ntest], predictions)\n_ = ax.set_xlabel(\"Label\")\n_ = ax.set_ylabel(\"Prediction\")\n

Indeed, the loss goes down and the network successfully classifies the patterns.

"},{"location":"tutorial/07_gradient_descent/#summary","title":"Summary","text":"

Puh, this was a pretty dense tutorial with a lot of material. You should have learned how to:

  • compute the gradient with respect to parameters
  • use parameter transformations
  • use multi-level checkpointing
  • define optimizers
  • write dataloaders and parallelize across data

This was the last \u201cbasic\u201d tutorial of the Jaxley toolbox. If you want to learn more, check out our Advanced Tutorials. If anything is still unclear please create a discussion. If you find any bugs, please open an issue. Happy coding!

"},{"location":"tutorial/08_importing_morphologies/","title":"Working with morphologies","text":"

In this tutorial, you will learn how to:

  • Load morphologies and make them compatible with Jaxley
  • Use the visualization features
  • Assemble a small network of morphologically accurate cells.

Here is a code snippet which you will learn to understand in this tutorial:

import jaxley as jx\n\ncell = jx.read_swc(\"my_cell.swc\", ncomp=4)\ncell.branch(2).set_ncomp(2)  # Modify the number of compartments of a branch.\n

To work with more complicated morphologies, Jaxley supports importing morphological reconstructions via .swc files. .swc is currently the only supported format. Other formats like .asc need to be converted to .swc first, for example using the BlueBrain\u2019s morph-tool. For more information on the exact specifications of .swc see here.

import jaxley as jx\nfrom jaxley.synapses import IonotropicSynapse\nimport matplotlib.pyplot as plt\n

To work with .swc files, Jaxley implements a custom .swc reader. The reader traces the morphology and identifies all uninterrupted sections. These uninterrupted sections are called branches in Jaxley. Each branch is then further partitioned into compartments.

To demonstrate this, let\u2019s import an example morphology of a Layer 5 pyramidal cell and visualize it.

# import swc file into jx.Cell object\nfname = \"data/morph.swc\"\ncell = jx.read_swc(fname, ncomp=8)  # Use eight compartments per branch.\n\n# print shape (num_branches, num_comps)\nprint(cell.shape)\n\ncell.show()\n
(157, 1256)\n
local_comp_index global_comp_index local_branch_index global_branch_index local_cell_index global_cell_index 0 0 0 0 0 0 0 1 1 1 0 0 0 0 2 2 2 0 0 0 0 3 3 3 0 0 0 0 4 4 4 0 0 0 0 ... ... ... ... ... ... ... 1251 3 1251 156 156 0 0 1252 4 1252 156 156 0 0 1253 5 1253 156 156 0 0 1254 6 1254 156 156 0 0 1255 7 1255 156 156 0 0

1256 rows \u00d7 6 columns

As we can see, this yields a morphology that is approximated by 1256 compartments. Depending on the amount of detail that you need, you can also change the number of compartments in each branch:

cell = jx.read_swc(fname, ncomp=2)\n\n# print shape (num_branches, num_comps)\nprint(cell.shape)\n\ncell.show()\n
(157, 314)\n
local_comp_index global_comp_index local_branch_index global_branch_index local_cell_index global_cell_index 0 0 0 0 0 0 0 1 1 1 0 0 0 0 2 0 2 1 1 0 0 3 1 3 1 1 0 0 4 0 4 2 2 0 0 ... ... ... ... ... ... ... 309 1 309 154 154 0 0 310 0 310 155 155 0 0 311 1 311 155 155 0 0 312 0 312 156 156 0 0 313 1 313 156 156 0 0

314 rows \u00d7 6 columns

The above assigns the same number of compartments to every branch. To use a different number of compartments in individual branches, you can use .set_ncomp():

cell.branch(1).set_ncomp(4)\n

As you can see below, branch 0 has two compartments (because this is what was passed to jx.read_swc(..., ncomp=2)), but branch 1 has four compartments:

cell.branch([0, 1]).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 0.050000 8.119000 5000.0 1.0 -70.0 0 0 0 0 1 0 0 1 0.050000 8.119000 5000.0 1.0 -70.0 0 0 1 0 2 0 1 0 3.120779 7.806172 5000.0 1.0 -70.0 0 1 2 1 3 0 1 1 3.120779 7.111231 5000.0 1.0 -70.0 0 1 3 1 4 0 1 2 3.120779 5.652394 5000.0 1.0 -70.0 0 1 4 1 5 0 1 3 3.120779 3.869247 5000.0 1.0 -70.0 0 1 5 1

Once imported the compartmentalized morphology can be viewed using vis.

# visualize the cell\ncell.vis()\nplt.axis(\"off\")\nplt.title(\"L5PC\")\nplt.show()\n

vis can be called on any jx.Module and every View of the module. This means we can also for example use vis to highlight each branch. This can be done by iterating over each branch index and calling cell.branch(i).vis(). Within the loop.

fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n# define colorwheel with 10 colors\ncolors = plt.cm.tab10.colors\nfor i, branch in enumerate(cell.branches):\n    branch.vis(ax=ax, color=colors[i % 10])\nplt.axis(\"off\")\nplt.title(\"Branches\")\nplt.show()\n

While we only use two compartments to approximate each branch in this example, we can see the morphology is still plotted in great detail. This is because we always plot the full .swc reconstruction irrespective of the number of compartments used. The morphology lives seperately in the cell.xyzr attribute in a per branch fashion.

In addition to plotting the full morphology of the cell using points vis(type=\"scatter\") or lines vis(type=\"line\"), Jaxley also supports plotting a detailed morphological vis(type=\"morph\") or approximate compartmental reconstruction vis(type=\"comp\") that correctly considers the thickness of the neurite. Note that \"comp\" plots the lengths of each compartment which is equal to the length of the traced neurite. While neurites can be zigzaggy, the compartments that approximate them are straight lines. This can lead to miss-aligment of the compartment ends. For details see the documentation of vis.

The morphologies can either be projected onto 2D or also rendered in 3D.

# visualize the cell\nfig, ax = plt.subplots(1, 4, figsize=(10, 3), layout=\"constrained\", sharex=True, sharey=True)\ncell.vis(ax=ax[0], type=\"morph\", dims=[0,1])\ncell.vis(ax=ax[1], type=\"comp\", dims=[0,1])\ncell.vis(ax=ax[2], type=\"scatter\", dims=[0,1], s=1)\ncell.vis(ax=ax[3], type=\"line\", dims=[0,1])\nfig.suptitle(\"Comparison of plot types\")\nplt.show()\n

# set to interactive mode\n# %matplotlib notebook\n
# plot in 3D\nfig = plt.figure()\nax = fig.add_subplot(111, projection='3d')\ncell.vis(ax=ax, type=\"line\", dims=[2,0,1])\nax.view_init(elev=20, azim=5)\nplt.show()\n

Since Jaxley supports grouping different branches or compartments together, we can also use the id labels provided by the .swc file to assign group labels to the jx.Cell object.

print(list(cell.groups.keys()))\n\nfig, ax = plt.subplots(1, 1, figsize=(5, 5))\ncolors = plt.cm.tab10.colors\ncell.basal.vis(ax=ax, color=colors[2])\ncell.soma.vis(ax=ax, color=colors[1])\ncell.apical.vis(ax=ax, color=colors[0])\nplt.axis(\"off\")\nplt.title(\"Groups\")\nplt.show()\n
['soma', 'basal', 'apical', 'custom']\n

To build a network of morphologically detailed cells, we can now connect several reconstructed cells together and also visualize the network. However, since all cells are going to have the same center, Jaxley will naively plot all of them on top of each other. To seperate out the cells, we therefore have to move them to a new location first.

net = jx.Network([cell]*5)\njx.connect(net[0,0,0], net[2,0,0], IonotropicSynapse())\njx.connect(net[0,0,0], net[3,0,0], IonotropicSynapse())\njx.connect(net[0,0,0], net[4,0,0], IonotropicSynapse())\n\njx.connect(net[1,0,0], net[2,0,0], IonotropicSynapse())\njx.connect(net[1,0,0], net[3,0,0], IonotropicSynapse())\njx.connect(net[1,0,0], net[4,0,0], IonotropicSynapse())\n\nnet.rotate(-90)\n\nnet.cell(0).move(0, 300)\nnet.cell(1).move(0, 500)\n\nnet.cell(2).move(900, 200)\nnet.cell(3).move(900, 400)\nnet.cell(4).move(900, 600)\n\nnet.vis()\nplt.axis(\"off\")\nplt.show()\n

Congrats! You have now learned how to vizualize and build networks out of very complex morphologies. To simulate this network, you can follow the steps in the tutorial on how to build a network.

"},{"location":"tutorial/09_advanced_indexing/","title":"Customizing synaptic parameters","text":"

In this tutorial, you will learn how to:

  • use the select() method to fully customize network simulations with Jaxley.
  • use the copy_node_property_to_edges() method to flexibly modify synapses.

Here is a code snippet which you will learn to understand in this tutorial:

net = ...  # See tutorial on Basics of Jaxley.\n\n# Set synaptic conductance of the synapse with index 0 and 1.\nnet.select(edges=[0, 1]).set(\"Ionotropic_gS\", 0.1)\n\n# Set synaptic conductance of all synapses that have cells 3 or 4 as presynaptic neuron.\nnet.copy_node_property_to_edges(\"global_cell_index\")\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [3, 4]\")\nnet.select(edges=df.index).set(\"Ionotropic_gS\", 0.2)\n\n# Set synaptic conductance of all synapses that\n# 1) have cells 2 or 3 as presynaptic neuron and\n# 2) has cell 5 as postsynaptic neuron\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [2, 3]\")\ndf = df.query(\"post_global_cell_index == 5\")\nnet.select(edges=df.index).set(\"Ionotropic_gS\", 0.3)\n

In a previous tutorial you learned how to set parameters of a jx.Network. In that tutorial, we briefly mentioned the select() method which allowed to set individual synapses to particular values. In this tutorial, we will go into detail in how you can fully customize your Jaxley simulation.

Let\u2019s go!

import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.connect import fully_connect\nfrom jaxley.synapses import IonotropicSynapse\n
"},{"location":"tutorial/09_advanced_indexing/#preface-building-the-network","title":"Preface: Building the network","text":"

We first build a network consisting of six neurons, in the same way as we showed in the previous tutorials:

dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=2)\ncell = jx.Cell(branch, parents=[-1, 0])\nnet = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n
"},{"location":"tutorial/09_advanced_indexing/#setting-individual-synapse-parameters","title":"Setting individual synapse parameters","text":"

As always, you can use the .edges table to inspect synaptic parameters of the network:

net.edges\n
global_edge_index pre_global_comp_index post_global_comp_index type type_ind pre_locs post_locs IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s controlled_by_param 0 0 0 13 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 1 1 0 19 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 2 2 0 20 IonotropicSynapse 0 0.25 0.25 0.0001 0.0 0.025 0.2 0 3 3 4 12 IonotropicSynapse 0 0.25 0.25 0.0001 0.0 0.025 0.2 0 4 4 4 16 IonotropicSynapse 0 0.25 0.25 0.0001 0.0 0.025 0.2 0 5 5 4 21 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 6 6 8 13 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 7 7 8 17 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 8 8 8 21 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0

This table has nine rows, each corresponding to one synapse. This makes sense because we fully connected three neurons (0, 1, 2) to three other neurons (3, 4, 5), giving a total of 3x3=9 synapses.

You can modify parameters of individual synapses as follows:

net.select(edges=[3, 4, 5]).set(\"IonotropicSynapse_gS\", 0.2)\n

Above, we are modifying the synapses with indices [3, 4, 5] (i.e., the indices of the net.edges DataFrame). The resulting values are indeed changed:

net.edges.IonotropicSynapse_gS\n
0    0.0001\n1    0.0001\n2    0.0001\n3    0.2000\n4    0.2000\n5    0.2000\n6    0.0001\n7    0.0001\n8    0.0001\nName: IonotropicSynapse_gS, dtype: float64\n
"},{"location":"tutorial/09_advanced_indexing/#example-1-setting-synaptic-parameters-which-connect-particular-neurons","title":"Example 1: Setting synaptic parameters which connect particular neurons","text":"

This is great, but setting synaptic parameters just by their index can be exhausting, in particular in very large networks. Instead, we would want to, for example, set the maximal conductance of all synapses that connect from cell 0 or 1 to any other neuron.

In Jaxley, such customization can be achieved by filtering the .edges dataframe accordingly, as shown below:

net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\nnet.copy_node_property_to_edges(\"global_cell_index\")\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [0, 1]\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.23)\n
net.edges.IonotropicSynapse_gS\n
0    0.2300\n1    0.2300\n2    0.2300\n3    0.2300\n4    0.2300\n5    0.2300\n6    0.0001\n7    0.0001\n8    0.0001\nName: IonotropicSynapse_gS, dtype: float64\n

Indeed, the first six synapses now have the value 0.23! Let\u2019s look at the individual lines to understand how this worked:

We want to set parameter by cell index. However, by default, the pre- or post-synaptic cell-indices are not listed in net.edges. We can add the cell index to the .edges dataframe by calling .copy_node_property_to_edges():

net.copy_node_property_to_edges(\"global_cell_index\")\n

After this, the pre- and post-synaptic cell indices are listed in net.edges as pre_global_cell_index and post_global_cell_index.

Next, we take .edges, which is a pandas DataFrame:

df = net.edges\n

We then modify this DataFrame to only contain those rows where the global cell index is in 0 or 1:

df = df.query(\"pre_global_cell_index in [0, 1]\")\n

For the above step, you use any column of the DataFrame to filter it (you can see all columns with df.columns). Note that, while we used .query() here, you can really filter the pandas DataFrame however you want. For example, the query above is identical to df = df[df[\"pre_global_cell_index\"].isin([0, 1])].

Finally, we use the .select() method, which returns a subset of the Network at the specified indices. This subset of the network can be modified with .set():

net.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.23)\n

"},{"location":"tutorial/09_advanced_indexing/#example-2-setting-parameters-given-pre-and-post-synaptic-cell-indices","title":"Example 2: Setting parameters given pre- and post-synaptic cell indices","text":"

Say you want to select all synapses that have cells 1 or 2 as presynaptic neuron and cell 4 or 5 as postsynaptic neuron.

net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n

Just like before, we can simply use .query() as already shown above. However, this time, call .query() to twice to filter by pre- and post-synaptic cell indices:

net.copy_node_property_to_edges(\"global_cell_index\")\n\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [1, 2]\")\ndf = df.query(\"post_global_cell_index in [4, 5]\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.3)\n
net.edges.IonotropicSynapse_gS\n
0    0.0001\n1    0.0001\n2    0.0001\n3    0.0001\n4    0.3000\n5    0.3000\n6    0.0001\n7    0.3000\n8    0.3000\nName: IonotropicSynapse_gS, dtype: float64\n
"},{"location":"tutorial/09_advanced_indexing/#example-3-applying-this-strategy-to-cell-level-parameters","title":"Example 3: Applying this strategy to cell level parameters","text":"

You had previously seen that you can modify parameters with, e.g., net.cell(0).set(...). However, if you need more flexibility than this, you can also use the above strategy to modify cell-level parameters:

net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\ndf = net.nodes\ndf = df.query(\"global_cell_index in [0, 1]\")\nnet.select(nodes=df.index).set(\"radius\", 0.1)\n
"},{"location":"tutorial/09_advanced_indexing/#example-4-flexibly-setting-parameters-based-on-their-groups","title":"Example 4: Flexibly setting parameters based on their groups","text":"

If you are using groups, as shown in this tutorial, then you can also use this for querying synapses. To demonstrate this, let\u2019s create a group of excitatory neurons (e.g., cells 0, 3, 5):

# Redefine network.\nnet = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\nnet.cell([0, 3, 5]).add_to_group(\"exc\")\n

Now, say we want all synapses that start from these excitatory neurons. You can do this as follows:

# First, we have to identify which cells are in the `exc` group.\nindices_of_excitatory_cells = net.exc.nodes[\"global_cell_index\"].unique().tolist()  # [0, 3, 5]\n\n# Then we can proceed as before:\nnet.copy_node_property_to_edges(\"global_cell_index\")\ndf = net.edges\ndf = df.query(f\"pre_global_cell_index in {indices_of_excitatory_cells}\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.4)\n
"},{"location":"tutorial/09_advanced_indexing/#example-5-setting-synaptic-parameters-based-on-properties-of-the-presynaptic-cell","title":"Example 5: Setting synaptic parameters based on properties of the presynaptic cell","text":"

Let\u2019s discuss one more example: Imagine we only want to modify those synapses whose presynaptic compartment has a sodium channel. Let\u2019s first add a sodium channel to some of the cells:

net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\nnet.cell(0).branch(0).comp(0).insert(Na())\nnet.cell(2).branch(1).comp(1).insert(Na())\n

Now, let us query which cells have the desired synapses:

df = net.nodes\ndf = df.query(\"Na\")\nindices_of_sodium_compartments = df[\"global_comp_index\"].unique().tolist()\n

indices_of_sodium_compartments lists all compartments which contained sodium:

print(indices_of_sodium_compartments)\n
[0, 11]\n

Then, we can proceed as always and filter for the global pre-synaptic compartment index:

df = net.edges\ndf = df.query(f\"pre_global_comp_index in {indices_of_sodium_compartments}\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.6)\n
net.edges.IonotropicSynapse_gS\n
0    0.6000\n1    0.6000\n2    0.6000\n3    0.0001\n4    0.0001\n5    0.0001\n6    0.0001\n7    0.0001\n8    0.0001\nName: IonotropicSynapse_gS, dtype: float64\n

Indeed, only synapses coming from the first neuron were modified (as its presynaptic compartment contained sodium), in contrast to synapses from neuron 2 (whose presynaptic compartment did not).

"},{"location":"tutorial/09_advanced_indexing/#summary","title":"Summary","text":"

In this tutorial, you learned how to fully customize your Jaxley simulation. This works by querying rows from the .edges DataFrame.

"},{"location":"tutorial/10_advanced_parameter_sharing/","title":"Synaptic parameter sharing","text":"

In this tutorial, you will learn how to:

  • flexibly share parameters of synapses

Here is a code snippet which you will learn to understand in this tutorial:

net = ...  # See tutorial on Basics of Jaxley.\n\n# The same parameter for all synapses\nnet.make_trainable(\"Ionotropic_gS\")\n\n# An individual parameter for every synapse.\nnet.select(edges=\"all\").make_trainable(\"Ionotropic_gS\")\n\n# Share synaptic conductances emerging from the same neurons.\nnet.copy_node_property_to_edges(\"cell_index\")\nsub_net = net.select(edges=[0, 1, 2])\nsub_net.edges[\"controlled_by_param\"] = sub_net.edges[\"pre_global_cell_index\"]\nsub_net.make_trainable(\"Ionotropic_gS\")\n

In a previous tutorial about training networks, we briefly touched on parameter sharing. In this tutorial, we will show you how you can flexibly share parameters within a network.

import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.connect import fully_connect\nfrom jaxley.synapses import IonotropicSynapse\n
"},{"location":"tutorial/10_advanced_parameter_sharing/#preface-building-the-network","title":"Preface: Building the network","text":"

We first build a network consisting of six neurons, in the same way as we showed in the previous tutorials:

dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0])\nnet = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n
"},{"location":"tutorial/10_advanced_parameter_sharing/#sharing-parameters-by-modifying-controlled_by_param","title":"Sharing parameters by modifying controlled_by_param","text":"
net.copy_node_property_to_edges(\"global_cell_index\")\n\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [0, 1, 2]\")\nsubnetwork = net.select(edges=df.index)\n\ndf = subnetwork.edges\ndf[\"controlled_by_param\"] = df[\"pre_global_cell_index\"]\nsubnetwork.make_trainable(\"IonotropicSynapse_gS\")\n
Number of newly added trainable parameters: 3. Total number of trainable parameters: 3\n

Let\u2019s look at this line by line. First, we exactly follow the previous tutorial in selecting the synapses which we are interested in training (i.e., the ones whose presynaptic neuron has index 0, 1, 2):

df = net.edges\ndf = df.query(\"pre_global_cell_index in [0, 1, 2]\")\nsubnetwork = net.select(edges=df.index)\n

As second step, we enable parameter sharing. This is done by setting the controlled_by_param. Synapses that have the same value in controlled_by_param will be shared. Let\u2019s inspect controlled_by_param before we modify it:

subnetwork.edges[[\"pre_global_cell_index\", \"controlled_by_param\"]]\n
pre_global_cell_index controlled_by_param 0 0 0 1 0 1 2 0 2 3 1 3 4 1 4 5 1 5 6 2 6 7 2 7 8 2 8

Every synapse has a different value. Because of this, no synaptic parameters will be shared. To enable parameter sharing we override the controlled_by_param column with the presynaptic cell index:

df = subnetwork.edges\ndf[\"controlled_by_param\"] = df[\"pre_global_cell_index\"]\n
df[[\"pre_global_cell_index\", \"controlled_by_param\"]]\n
pre_global_cell_index controlled_by_param 0 0 0 1 0 0 2 0 0 3 1 1 4 1 1 5 1 1 6 2 2 7 2 2 8 2 2

Now, all we have to do is to make these synaptic parameters trainable with the make_trainable() method:

subnetwork.make_trainable(\"IonotropicSynapse_gS\")\n
Number of newly added trainable parameters: 3. Total number of trainable parameters: 6\n

It correctly says that we added three parameters (because we have three cells, and we share individual synaptic parameters). We now have 6 trainable parameters in total (because we already added 3 trainable parameters above).

"},{"location":"tutorial/10_advanced_parameter_sharing/#a-more-involved-example-sharing-by-pre-and-post-synaptic-cell-type","title":"A more involved example: sharing by pre- and post-synaptic cell type","text":"

As an example, consider the following: We have a fully connected network of six cells. Each cell falls into one of three cell types:

from typing import Union, List\n
net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell(\"all\"), net.cell(\"all\"), IonotropicSynapse())\n\nnet.cell([0, 1]).add_to_group(\"exc\")\nnet.cell([2, 3]).add_to_group(\"inh\")\nnet.cell([4, 5]).add_to_group(\"unknown\")\n

We want to make all synapses that start from excitatory or inhibitory neurons trainable. In addition, we want to use the same parameter for synapses if they have the same pre- and post-synaptic cell type.

To achieve this, we will first want a column in net.nodes which indicates the cell type.

for group, inds in net.groups.items():\n    net.nodes.loc[inds, \"cell_type\"] = group\n
net.nodes[\"cell_type\"]\n
0         exc\n1         exc\n2         exc\n3         exc\n4         exc\n5         exc\n6         exc\n7         exc\n8         inh\n9         inh\n10        inh\n11        inh\n12        inh\n13        inh\n14        inh\n15        inh\n16    unknown\n17    unknown\n18    unknown\n19    unknown\n20    unknown\n21    unknown\n22    unknown\n23    unknown\nName: cell_type, dtype: object\n

The cell_type is now part of the net.nodes. However, we would like to do parameter sharing of synapses based on the pre- and post-synaptic node values. To do so, we import the cell_type column into net.edges. To do this, we use the .copy_node_property_to_edges() which the name of the property you are copying from nodes:

net.copy_node_property_to_edges(\"cell_type\")\n

After this, you have columns in the .edges which indicate the pre- and post-synaptic cell type:

net.edges[[\"pre_cell_type\", \"post_cell_type\"]]\n
pre_cell_type post_cell_type 0 exc exc 1 exc exc 2 exc inh 3 exc inh 4 exc unknown 5 exc unknown 6 exc exc 7 exc exc 8 exc inh 9 exc inh 10 exc unknown 11 exc unknown 12 inh exc 13 inh exc 14 inh inh 15 inh inh 16 inh unknown 17 inh unknown 18 inh exc 19 inh exc 20 inh inh 21 inh inh 22 inh unknown 23 inh unknown 24 unknown exc 25 unknown exc 26 unknown inh 27 unknown inh 28 unknown unknown 29 unknown unknown 30 unknown exc 31 unknown exc 32 unknown inh 33 unknown inh 34 unknown unknown 35 unknown unknown

Next, we specify which parts of the network we actually want to change (in this case, all synapses which have excitatory or inhibitory presynaptic neurons):

df = net.edges\ndf = df.query(f\"pre_cell_type in ['exc', 'inh']\")\nprint(f\"There are {len(df)} synapses to be changed.\")\n\nsubnetwork = net.select(edges=df.index)\n
There are 24 synapses to be changed.\n

As the last step, we again have to specify parameter sharing by setting controlled_by_param. In this case, we want to share parameters that have the same pre- and post-synaptic neuron. We achieve this by grouping the synpases by their pre- and post-synaptic cell type (see pd.DataFrame.groupby for details):

# Step 6: use groupby to specify parameter sharing and make the parameters trainable.\nsubnetwork.edges[\"controlled_by_param\"] = subnetwork.edges.groupby([\"pre_cell_type\", \"post_cell_type\"]).ngroup()\nsubnetwork.make_trainable(\"IonotropicSynapse_gS\")\n
Number of newly added trainable parameters: 6. Total number of trainable parameters: 6\n

This created six trainable parameters, which makes sense as we have two types of pre-synaptic neurons (excitatory and inhibitory) and each has three options for the postsynaptic neuron (pre, post, unknown).

"},{"location":"tutorial/10_advanced_parameter_sharing/#summary","title":"Summary","text":"

In this tutorial, you learned how you can flexibly share synaptic parameters. This works by first using select() to identify which synapses to make trainable, and by then modifying controlled_by_param to customize parameter sharing.

"}]} \ No newline at end of file