Skip to content

Commit

Permalink
Merge pull request #797 from apdavison/fix-channel-id
Browse files Browse the repository at this point in the history
Harmonize spike train annotations
  • Loading branch information
apdavison authored Apr 30, 2024
2 parents 1fbf10a + d5d6166 commit 5b6af57
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
4 changes: 2 additions & 2 deletions doc/pyplots/neo_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@

def plot_spiketrains(segment):
for spiketrain in segment.spiketrains:
y = np.ones_like(spiketrain) * spiketrain.annotations['source_id']
y = np.ones_like(spiketrain) * spiketrain.annotations['channel_id']
plt.plot(spiketrain, y, '.')
plt.ylabel(segment.name)
plt.setp(plt.gca().get_xticklabels(), visible=False)


def plot_signal(signal, index, colour='b'):
label = "Neuron %d" % signal.annotations['source_ids'][index]
label = "Neuron %d" % signal.annotations['channel_ids'][index]
plt.plot(signal.times, signal[:, index], colour, label=label)
plt.ylabel("%s (%s)" % (signal.name, signal.units._dimensionality.string))
plt.setp(plt.gca().get_xticklabels(), visible=False)
Expand Down
24 changes: 13 additions & 11 deletions pyNN/recording/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def gather_blocks(data, ordered=True):
if ordered:
for segment in merged.segments:
ordered_spiketrains = sorted(
segment.spiketrains, key=lambda s: s.annotations['source_id'])
segment.spiketrains, key=lambda s: s.annotations['channel_id'])
segment.spiketrains = ordered_spiketrains
return merged

Expand Down Expand Up @@ -319,7 +319,7 @@ def _get_current_segment(self, filter_ids=None, variables='all', clear=False):
t_stop=t_stop,
units='ms',
source_population=self.population.label,
source_id=int(id),
channel_id=int(id),
source_index=self.population.id_to_index(int(id)))
)
for train in segment.spiketrains:
Expand All @@ -333,13 +333,15 @@ def _get_current_segment(self, filter_ids=None, variables='all', clear=False):
mask = times <= t_stop
times = times[mask]
id_array = id_array[mask]
channel_ids = np.array(sids, dtype=int)
segment.spiketrains = neo.spiketrainlist.SpikeTrainList.from_spike_time_array(
times, id_array,
np.array(sids, dtype=int),
channel_ids,
t_stop=t_stop,
units="ms",
t_start=self._recording_start_time,
source_population=self.population.label
source_population=self.population.label,
source_index=self.population.id_to_index(channel_ids)
)
segment.spiketrains.segment = segment
else:
Expand All @@ -349,7 +351,7 @@ def _get_current_segment(self, filter_ids=None, variables='all', clear=False):
if signal_array.size > 0:
# may be empty if none of the recorded cells are on this MPI node
units = self.population.find_units(variable)
source_ids = np.fromiter(ids, dtype=int)
channel_ids = np.fromiter(ids, dtype=int)
if len(ids) == signal_array.shape[1]: # one channel per neuron
channel_index = np.array([self.population.id_to_index(id) for id in ids])
else: # multiple recording locations per neuron
Expand All @@ -372,19 +374,19 @@ def _get_current_segment(self, filter_ids=None, variables='all', clear=False):
units=units,
time_units=pq.ms,
name=signal_name,
source_ids=[source_id],
channel_ids=[channel_id],
source_population=self.population.label,
array_annotations={"channel_index": [i]}
)
for i, source_id in zip(channel_index, source_ids)
for i, channel_id in zip(channel_index, channel_ids)
]
else:
# all channels have the same sample times
assert signal_array.shape[0] == times_array.size
signals = [
neo.IrregularlySampledSignal(
times_array, signal_array, units=units, time_units=pq.ms,
name=signal_name, source_ids=source_ids,
name=signal_name, channel_ids=channel_ids,
source_population=self.population.label,
array_annotations={"channel_index": channel_index}
)
Expand All @@ -402,13 +404,13 @@ def _get_current_segment(self, filter_ids=None, variables='all', clear=False):
units=units,
t_start=t_start,
sampling_period=sampling_period,
name=signal_name, source_ids=source_ids,
name=signal_name, channel_ids=channel_ids,
source_population=self.population.label,
array_annotations={"channel_index": channel_index}
)
assert signal.t_stop - current_time - 2 * sampling_period < 1e-10
logger.debug("%d **** ids=%s, channels=%s", mpi_node,
source_ids, signal.array_annotations["channel_index"])
logger.debug("%d **** channel_ids=%s, channel_index=%s", mpi_node,
channel_ids, signal.array_annotations["channel_index"])
segment.analogsignals.append(signal)
signal.segment = segment
return segment
Expand Down
2 changes: 1 addition & 1 deletion pyNN/serialization/sonata.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def read(self):
t_stop=spike_times.max() + 1.0,
t_start=0.0,
units='ms',
source_id=gid)
channel_id=gid)
)
block.segments.append(segment)
return [block]
Expand Down

0 comments on commit 5b6af57

Please sign in to comment.