Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Network plot #104

Merged
merged 31 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
69979f8
type annotate get_precise_record_dt
anilbey Oct 5, 2023
19f1ae4
update docs of Cell constructor
anilbey Oct 5, 2023
3c47384
fix get_section_id attribute error
anilbey Oct 5, 2023
08e4bb5
update Cell::area docstring
anilbey Oct 5, 2023
d32d208
add unit test for Cell::area
anilbey Oct 5, 2023
ee58b91
synlocation_to_segx to always return float
anilbey Oct 5, 2023
00a578f
merge main into terminology
anilbey Oct 13, 2023
17acb22
make use_random123_stochkv non nullable
anilbey Oct 13, 2023
b473ac3
add return type annotations
anilbey Oct 13, 2023
b388465
Merge branch 'main' into terminology
anilbey Oct 24, 2023
132fd55
update until pre_gid_synapse_ids
anilbey Oct 25, 2023
cea9219
remove ssim.connections
anilbey Oct 25, 2023
e4582b8
keep cell_id inside cell, not gid
anilbey Oct 26, 2023
7d0baca
keep cell_id in synapse, not gid
anilbey Oct 26, 2023
e5d10e8
draft: add network plot
anilbey Oct 26, 2023
0693ee6
build and plot the graph within the 'sonata-network' notebook
ilkilic Nov 1, 2023
0b29cad
Merge branch 'main' into network-plot
ilkilic Nov 22, 2023
1066e3d
fix kernelspec in sonata-network.ipynb
ilkilic Nov 22, 2023
98f0b42
lint fix
ilkilic Nov 22, 2023
ad5c5e6
Add 'networkx' to package dependencies in setup.py
ilkilic Nov 22, 2023
b6d5a13
Add mypy ignore comments for matplotlib attrs
ilkilic Nov 22, 2023
60ffb6f
add unit tests for graph.py
ilkilic Dec 4, 2023
7cea8f4
update NetworkX minimum version requirement
ilkilic Dec 4, 2023
fef14e9
add comments in sonata-network notebook
ilkilic Dec 4, 2023
4bd15e6
Merge branch 'main' into network-plot
ilkilic Dec 4, 2023
9fbc539
lint fix
ilkilic Dec 4, 2023
35e1286
downgrade networkx to 3.1 for compatibility
ilkilic Dec 4, 2023
020f658
fix variable name
ilkilic Dec 4, 2023
d3626a1
fix template_format to v6
ilkilic Dec 4, 2023
008e150
refactor tests to use pytest framework
ilkilic Dec 4, 2023
54a13cf
lint fix
ilkilic Dec 4, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions bluecellulab/cell/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,12 @@
class Cell(InjectableMixin, PlottableMixin):
"""Represents a Cell object."""

max_id = 0
anilbey marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self,
template_path: str | Path,
morphology_path: str | Path,
gid: int = 0,
cell_id: Optional[CellId] = None,
record_dt: Optional[float] = None,
template_format: str = "v5",
emodel_properties: Optional[EmodelProperties] = None,
Expand All @@ -74,6 +76,10 @@ def __init__(self,
object used by the Cell. Defaults to None.
"""
super().__init__()
if cell_id is None:
cell_id = CellId("", Cell.max_id)
Cell.max_id += 1
self.cell_id = cell_id
# Persistent objects, like clamps, that exist as long
# as the object exists
self.persistent: list[HocObjectType] = []
Expand All @@ -83,7 +89,7 @@ def __init__(self,
# Load the template
neuron_template = NeuronTemplate(template_path, morphology_path)
self.template_id = neuron_template.template_name # useful to map NEURON and python objects
self.cell = neuron_template.get_cell(template_format, gid, emodel_properties)
self.cell = neuron_template.get_cell(template_format, self.cell_id.id, emodel_properties)
self.soma = public_hoc_cell(self.cell).soma[0]
# WARNING: this finitialize 'must' be here, otherwhise the
# diameters of the loaded morph are wrong
Expand All @@ -92,8 +98,7 @@ def __init__(self,
self.cellname = neuron.h.secname(sec=self.soma).split(".")[0]

# Set the gid of the cell
public_hoc_cell(self.cell).gid = gid
self.gid = gid
self.cell.getCell().gid = self.cell_id.id

if rng_settings is None:
self.rng_settings = RNGSettings("Random123") # SONATA value
Expand Down Expand Up @@ -212,17 +217,17 @@ def re_init_rng(self, use_random123_stochkv: bool = False) -> None:
for section in self.somatic:
for seg in section:
neuron.h.setdata_StochKv(seg.x, sec=section)
neuron.h.setRNG_StochKv(channel_id, self.gid)
neuron.h.setRNG_StochKv(channel_id, self.cell_id.id)
channel_id += 1
for section in self.basal:
for seg in section:
neuron.h.setdata_StochKv(seg.x, sec=section)
neuron.h.setRNG_StochKv(channel_id, self.gid)
neuron.h.setRNG_StochKv(channel_id, self.cell_id.id)
channel_id += 1
for section in self.apical:
for seg in section:
neuron.h.setdata_StochKv(seg.x, sec=section)
neuron.h.setRNG_StochKv(channel_id, self.gid)
neuron.h.setRNG_StochKv(channel_id, self.cell_id.id)
channel_id += 1
else:
self.cell.re_init_rng()
Expand Down Expand Up @@ -444,7 +449,7 @@ def add_replay_synapse(self,

self.synapses[synapse_id] = synapse

logger.debug(f'Added synapse to cell {self.gid}')
logger.debug(f'Added synapse to cell {self.cell_id.id}')

def add_replay_delayed_weight(
self, sid: tuple[str, int], delay: float, weight: float
Expand Down Expand Up @@ -590,10 +595,10 @@ def add_replay_minis(self,
+ self.rng_settings.minis_seed
self.ips[synapse_id].setRNGs(
sid + 200,
self.gid + 250,
self.cell_id.id + 250,
seed2 + 300,
sid + 200,
self.gid + 250,
self.cell_id.id + 250,
seed2 + 350)
else:
exprng = bluecellulab.neuron.h.Random()
Expand All @@ -604,18 +609,18 @@ def add_replay_minis(self,

if self.rng_settings.mode == 'Compatibility':
exp_seed1 = sid * 100000 + 200
exp_seed2 = self.gid + 250 + base_seed + \
exp_seed2 = self.cell_id.id + 250 + base_seed + \
self.rng_settings.minis_seed
uniform_seed1 = sid * 100000 + 300
uniform_seed2 = self.gid + 250 + base_seed + \
uniform_seed2 = self.cell_id.id + 250 + base_seed + \
self.rng_settings.minis_seed
elif self.rng_settings.mode == "UpdatedMCell":
exp_seed1 = sid * 1000 + 200
exp_seed2 = source_popid * 16777216 + self.gid + 250 + \
exp_seed2 = source_popid * 16777216 + self.cell_id.id + 250 + \
base_seed + \
self.rng_settings.minis_seed
uniform_seed1 = sid * 1000 + 300
uniform_seed2 = source_popid * 16777216 + self.gid + 250 \
uniform_seed2 = source_popid * 16777216 + self.cell_id.id + 250 \
+ base_seed + \
self.rng_settings.minis_seed
else:
Expand Down Expand Up @@ -788,7 +793,7 @@ def add_synapse_replay(
spike_location=spike_location,
)
logger.debug(
f"Added synapse replay from {pre_gid} to {self.gid}, {synapse_id}"
f"Added synapse replay from {pre_gid} to {self.cell_id.id}, {synapse_id}"
)

self.connections[synapse_id] = connection
Expand Down
10 changes: 5 additions & 5 deletions bluecellulab/cell/injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,20 +156,20 @@ def add_voltage_clamp(
def _get_noise_step_rand(self, noisestim_count):
"""Return rng for noise step stimulus."""
if self.rng_settings.mode == "Compatibility":
rng = bluecellulab.neuron.h.Random(self.gid + noisestim_count)
rng = bluecellulab.neuron.h.Random(self.cell_id.id + noisestim_count)
elif self.rng_settings.mode == "UpdatedMCell":
rng = bluecellulab.neuron.h.Random()
rng.MCellRan4(
noisestim_count * 10000 + 100,
self.rng_settings.base_seed +
self.rng_settings.stimulus_seed +
self.gid * 1000)
self.cell_id.id * 1000)
elif self.rng_settings.mode == "Random123":
rng = bluecellulab.neuron.h.Random()
rng.Random123(
noisestim_count + 100,
self.rng_settings.stimulus_seed + 500,
self.gid + 300)
self.cell_id.id + 300)

self.persistent.append(rng)
return rng
Expand Down Expand Up @@ -235,7 +235,7 @@ def _get_ornstein_uhlenbeck_rand(self, stim_count, seed):
if self.rng_settings.mode == "Random123":
seed1 = stim_count + 2997 # stimulus block
seed2 = self.rng_settings.stimulus_seed + 291204 # stimulus type
seed3 = self.gid + 123 if seed is None else seed # GID
seed3 = self.cell_id.id + 123 if seed is None else seed # GID
logger.debug("Using ornstein_uhlenbeck process seeds %d %d %d" %
(seed1, seed2, seed3))
rng = bluecellulab.neuron.h.Random()
Expand All @@ -251,7 +251,7 @@ def _get_shotnoise_step_rand(self, shotnoise_stim_count, seed=None):
if self.rng_settings.mode == "Random123":
seed1 = shotnoise_stim_count + 2997
seed2 = self.rng_settings.stimulus_seed + 19216
seed3 = self.gid + 123 if seed is None else seed
seed3 = self.cell_id.id + 123 if seed is None else seed
logger.debug("Using shot noise seeds %d %d %d" %
(seed1, seed2, seed3))
rng = bluecellulab.neuron.h.Random()
Expand Down
7 changes: 4 additions & 3 deletions bluecellulab/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np

import bluecellulab
from bluecellulab.cell.core import Cell
from bluecellulab.circuit import SynapseProperty


Expand All @@ -28,7 +29,7 @@ def __init__(
self,
post_synapse,
pre_spiketrain: Optional[np.ndarray] = None,
pre_cell=None,
pre_cell: Optional[Cell] = None,
stim_dt=None,
parallel_context=None,
spike_threshold: float = -30.0,
Expand Down Expand Up @@ -76,7 +77,7 @@ def __init__(
self.post_netcon = self.pre_cell.create_netcon_spikedetector(
self.post_synapse.hsynapse, location=spike_location,
threshold=spike_threshold) if self.pc is None else \
self.pc.gid_connect(self.pre_cell.gid, self.post_synapse.hsynapse)
self.pc.gid_connect(self.pre_cell.cell_id.id, self.post_synapse.hsynapse)
self.post_netcon.weight[0] = self.post_netcon_weight
self.post_netcon.delay = self.post_netcon_delay
self.post_netcon.threshold = spike_threshold
Expand All @@ -94,7 +95,7 @@ def info_dict(self):
connection_dict = {}

connection_dict['pre_cell_id'] = self.post_synapse.pre_gid
connection_dict['post_cell_id'] = self.post_synapse.post_gid
connection_dict['post_cell_id'] = self.post_synapse.post_cell_id.id
connection_dict['post_synapse_id'] = self.post_synapse.syn_id.sid

connection_dict['post_netcon'] = {}
Expand Down
85 changes: 85 additions & 0 deletions bluecellulab/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Graph representation of Cells and Synapses."""

from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np


from bluecellulab.cell.cell_dict import CellDict


def build_graph(cells: CellDict) -> nx.DiGraph:
G = nx.DiGraph()

# Add nodes (cells) to the graph
for cell_id, cell in cells.items():
G.add_node(cell_id, label=str(cell_id.id), population=cell_id.population_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I can tell, when fetching the population_name from G afterwards, population_name is not used, but cell_id.population_name.
Is it useful to add a population_name to G if you only fetch it from G.cell_id?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While plotting the plot in the later step the population name gives the color of the node and it is read from G.


# Extract and add edges (connections) to the graph from each cell
for cell_id, cell in cells.items():
for connection in cell.connections.values():
# Check if pre_cell exists for the connection
if connection.pre_cell is None:
continue

# Source is the pre_cell from the connection
source_cell_id = connection.pre_cell.cell_id

# Target is the post-synapse cell from the connection
target_cell_id = connection.post_synapse.post_cell_id

# Check if both source and target cells are within the current cell collection
if source_cell_id in cells and target_cell_id in cells:
G.add_edge(source_cell_id, target_cell_id, weight=connection.weight)

return G


def plot_graph(G: nx.Graph, node_size: float = 400, edge_width: float = 0.4, node_distance: float = 1.6):
# Extract unique populations from the graph nodes
populations = list(set([cell_id.population_name for cell_id in G.nodes()]))

# Create a color map for each population
color_map = plt.cm.tab20(np.linspace(0, 1, len(populations))) # type: ignore[attr-defined]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to be sure, do we expect populations to be <= 20 in size?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, definitely. The populations here are referring to multiple circuits which is usually <=3.

population_color = dict(zip(populations, color_map))

# Create node colors based on their population
node_colors = [population_color[node.population_name] for node in G.nodes()]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, are we getting the population name property of each node, or are we getting the population of the cell_id in each node? I am asking because at line 42, we are doing the same thing, but the variable is named cell_id.
It would be good to be precise about what we are doing, for readability

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually line 42 is constructing a set and removing the duplicates. This line 49 constructs a list with duplicates.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I know. But what I meant, is that even though in line 42 it is named cell_id, and here it is named node, it looks like it is the same variable we access. In G, we have something like this for each node (cell_id, id, population_name), and my question is, when we are accessing population_name, are we getting the population_name in the node (the 3rd element in my set), or are we accessing cell_id.population_name (via the 1st element in my set).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh thanks, a good catch. The "label" and "population" attributes of the node are never used in G. There is duplicated information. Instead of this:
G.add_node(cell_id, label=str(cell_id.id), population=cell_id.population_name)
we can just write this:
G.add_node(cell_id)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Always the 1st element in your set is used, the 3rd is never used.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, thanks for clarifying

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for noticing, let me write a patch


# Extract weights for edge color mapping
edge_weights = [d['weight'] for _, _, d in G.edges(data=True)]
edge_colors = plt.cm.Greens(np.interp(edge_weights, (min(edge_weights), max(edge_weights)), (0.3, 1))) # type: ignore[attr-defined]

# Create positions using spring layout for the entire graph
pos = nx.spring_layout(G, k=node_distance)

# Create labels only for the node ID
labels = {node: node.id for node in G.nodes()}

# Create a figure and axis for the drawing
fig, ax = plt.subplots(figsize=(6, 5))

# Draw the graph
nx.draw(G, pos, with_labels=True, labels=labels, node_color=node_colors,
edge_color=edge_colors, width=edge_width, node_size=node_size, ax=ax, connectionstyle='arc3, rad = 0.1')

# Draw directed edges
nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=edge_width, ax=ax, arrowstyle='-|>', arrowsize=20, connectionstyle='arc3, rad = 0.1')

# Create a legend
for population, color in population_color.items():
plt.plot([0], [0], color=color, label=population)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, good to know. In this plot, the former approach avoids the overlaps in legends.
With the patches approach it looks like we need to explicitly specify the location of each legend.

plt_before

plt_with_patch

plt.legend(loc="upper left", bbox_to_anchor=(-0.1, 1.05)) # Adjust these values as needed

# Add a colorbar for edge weights
sm = ScalarMappable(cmap="Greens", norm=Normalize(vmin=min(edge_weights), vmax=max(edge_weights)))
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, orientation="vertical", fraction=0.03, pad=0.04)
cbar.set_label('Synaptic Strength')

# Add text at the bottom of the figure
plt.figtext(0.5, 0.01, "Network of simulated cells", ha="center", fontsize=10, va="bottom")

plt.show()
2 changes: 1 addition & 1 deletion bluecellulab/simulation/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def run(

if self.pc is not None:
for cell in self.cells:
self.pc.prcellstate(cell.gid, f"bluecellulab_t={bluecellulab.neuron.h.t}")
self.pc.prcellstate(cell.cell_id.id, f"bluecellulab_t={bluecellulab.neuron.h.t}")

try:
neuron.h.continuerun(neuron.h.tstop)
Expand Down
8 changes: 2 additions & 6 deletions bluecellulab/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from __future__ import annotations
from collections.abc import Iterable
from collections import defaultdict
from pathlib import Path
from typing import Optional
import logging
Expand Down Expand Up @@ -100,9 +99,6 @@ def __init__(
self.cells: CellDict = CellDict()

self.gids_instantiated = False
self.connections: defaultdict = defaultdict(
lambda: defaultdict(lambda: None)
)

# Make sure tstop is set correctly, because it is used by the
# TStim noise stimulus
Expand Down Expand Up @@ -732,7 +728,7 @@ def fetch_cell_kwargs(self, cell_id: CellId) -> dict:
cell_kwargs = {
'template_path': self.circuit_access.emodel_path(cell_id),
'morphology_path': self.circuit_access.morph_filepath(cell_id),
'gid': cell_id.id,
'cell_id': cell_id,
'record_dt': self.record_dt,
'rng_settings': self.rng_settings,
'template_format': self.circuit_access.get_template_format(),
Expand All @@ -746,7 +742,7 @@ def create_cell_from_circuit(self, cell_id: CellId) -> bluecellulab.Cell:
cell_kwargs = self.fetch_cell_kwargs(cell_id)
return bluecellulab.Cell(template_path=cell_kwargs['template_path'],
morphology_path=cell_kwargs['morphology_path'],
gid=cell_kwargs['gid'],
cell_id=cell_kwargs['cell_id'],
record_dt=cell_kwargs['record_dt'],
rng_settings=cell_kwargs['rng_settings'],
template_format=cell_kwargs['template_format'],
Expand Down
6 changes: 3 additions & 3 deletions bluecellulab/synapse/synapse_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ def create_synapse(
randomize_gaba_risetime = condition_parameters.randomize_gaba_rise_time
else:
randomize_gaba_risetime = True
synapse = GabaabSynapse(cell.rng_settings, cell.gid, syn_hoc_args, syn_id, syn_description,
synapse = GabaabSynapse(cell.rng_settings, cell.cell_id, syn_hoc_args, syn_id, syn_description,
popids, extracellular_calcium, randomize_gaba_risetime)
elif syn_type == SynapseType.AMPANMDA:
synapse = AmpanmdaSynapse(cell.rng_settings, cell.gid, syn_hoc_args, syn_id, syn_description,
synapse = AmpanmdaSynapse(cell.rng_settings, cell.cell_id, syn_hoc_args, syn_id, syn_description,
popids, extracellular_calcium)
else:
synapse = GluSynapse(cell.rng_settings, cell.gid, syn_hoc_args, syn_id, syn_description,
synapse = GluSynapse(cell.rng_settings, cell.cell_id, syn_hoc_args, syn_id, syn_description,
popids, extracellular_calcium)

synapse = cls.apply_connection_modifiers(connection_modifiers, synapse)
Expand Down
Loading