From d5d6166d72e71e59c9cb4edea735305ef25db524 Mon Sep 17 00:00:00 2001 From: Andrew Davison Date: Tue, 30 Apr 2024 11:12:45 +0200 Subject: [PATCH] Harmonize spike train creation, following addition of SpikeTrainList (fix error when using MPI and pyNN.nest) --- doc/pyplots/neo_example.py | 4 ++-- pyNN/recording/__init__.py | 24 +++++++++++++----------- pyNN/serialization/sonata.py | 2 +- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/doc/pyplots/neo_example.py b/doc/pyplots/neo_example.py index 7a625af2e..fc3159fb2 100644 --- a/doc/pyplots/neo_example.py +++ b/doc/pyplots/neo_example.py @@ -32,14 +32,14 @@ def plot_spiketrains(segment): for spiketrain in segment.spiketrains: - y = np.ones_like(spiketrain) * spiketrain.annotations['source_id'] + y = np.ones_like(spiketrain) * spiketrain.annotations['channel_id'] plt.plot(spiketrain, y, '.') plt.ylabel(segment.name) plt.setp(plt.gca().get_xticklabels(), visible=False) def plot_signal(signal, index, colour='b'): - label = "Neuron %d" % signal.annotations['source_ids'][index] + label = "Neuron %d" % signal.annotations['channel_ids'][index] plt.plot(signal.times, signal[:, index], colour, label=label) plt.ylabel("%s (%s)" % (signal.name, signal.units._dimensionality.string)) plt.setp(plt.gca().get_xticklabels(), visible=False) diff --git a/pyNN/recording/__init__.py b/pyNN/recording/__init__.py index 4e3b6d79a..666309591 100644 --- a/pyNN/recording/__init__.py +++ b/pyNN/recording/__init__.py @@ -105,7 +105,7 @@ def gather_blocks(data, ordered=True): if ordered: for segment in merged.segments: ordered_spiketrains = sorted( - segment.spiketrains, key=lambda s: s.annotations['source_id']) + segment.spiketrains, key=lambda s: s.annotations['channel_id']) segment.spiketrains = ordered_spiketrains return merged @@ -319,7 +319,7 @@ def _get_current_segment(self, filter_ids=None, variables='all', clear=False): t_stop=t_stop, units='ms', source_population=self.population.label, - source_id=int(id), + channel_id=int(id), source_index=self.population.id_to_index(int(id))) ) for train in segment.spiketrains: @@ -333,13 +333,15 @@ def _get_current_segment(self, filter_ids=None, variables='all', clear=False): mask = times <= t_stop times = times[mask] id_array = id_array[mask] + channel_ids = np.array(sids, dtype=int) segment.spiketrains = neo.spiketrainlist.SpikeTrainList.from_spike_time_array( times, id_array, - np.array(sids, dtype=int), + channel_ids, t_stop=t_stop, units="ms", t_start=self._recording_start_time, - source_population=self.population.label + source_population=self.population.label, + source_index=self.population.id_to_index(channel_ids) ) segment.spiketrains.segment = segment else: @@ -349,7 +351,7 @@ def _get_current_segment(self, filter_ids=None, variables='all', clear=False): if signal_array.size > 0: # may be empty if none of the recorded cells are on this MPI node units = self.population.find_units(variable) - source_ids = np.fromiter(ids, dtype=int) + channel_ids = np.fromiter(ids, dtype=int) if len(ids) == signal_array.shape[1]: # one channel per neuron channel_index = np.array([self.population.id_to_index(id) for id in ids]) else: # multiple recording locations per neuron @@ -372,11 +374,11 @@ def _get_current_segment(self, filter_ids=None, variables='all', clear=False): units=units, time_units=pq.ms, name=signal_name, - source_ids=[source_id], + channel_ids=[channel_id], source_population=self.population.label, array_annotations={"channel_index": [i]} ) - for i, source_id in zip(channel_index, source_ids) + for i, channel_id in zip(channel_index, channel_ids) ] else: # all channels have the same sample times @@ -384,7 +386,7 @@ def _get_current_segment(self, filter_ids=None, variables='all', clear=False): signals = [ neo.IrregularlySampledSignal( times_array, signal_array, units=units, time_units=pq.ms, - name=signal_name, source_ids=source_ids, + name=signal_name, channel_ids=channel_ids, source_population=self.population.label, array_annotations={"channel_index": channel_index} ) @@ -402,13 +404,13 @@ def _get_current_segment(self, filter_ids=None, variables='all', clear=False): units=units, t_start=t_start, sampling_period=sampling_period, - name=signal_name, source_ids=source_ids, + name=signal_name, channel_ids=channel_ids, source_population=self.population.label, array_annotations={"channel_index": channel_index} ) assert signal.t_stop - current_time - 2 * sampling_period < 1e-10 - logger.debug("%d **** ids=%s, channels=%s", mpi_node, - source_ids, signal.array_annotations["channel_index"]) + logger.debug("%d **** channel_ids=%s, channel_index=%s", mpi_node, + channel_ids, signal.array_annotations["channel_index"]) segment.analogsignals.append(signal) signal.segment = segment return segment diff --git a/pyNN/serialization/sonata.py b/pyNN/serialization/sonata.py index 98b2c4654..51e9de498 100644 --- a/pyNN/serialization/sonata.py +++ b/pyNN/serialization/sonata.py @@ -98,7 +98,7 @@ def read(self): t_stop=spike_times.max() + 1.0, t_start=0.0, units='ms', - source_id=gid) + channel_id=gid) ) block.segments.append(segment) return [block]