diff --git a/bluecellulab/analysis/inject_sequence.py b/bluecellulab/analysis/inject_sequence.py index d4d52d6d..598d17c1 100644 --- a/bluecellulab/analysis/inject_sequence.py +++ b/bluecellulab/analysis/inject_sequence.py @@ -74,6 +74,7 @@ def apply_multiple_stimuli( cell: Cell, stimulus_name: StimulusName, amplitudes: Sequence[float], + threshold_based: bool = True, section_name: str | None = None, segment: float = 0.5, n_processes: int | None = None, @@ -84,6 +85,8 @@ def apply_multiple_stimuli( cell: The cell to which the stimuli are applied. stimulus_name: The name of the stimulus to apply. amplitudes: The amplitudes of the stimuli to apply. + threshold_based: Whether to consider amplitudes to be + threshold percentages or to be raw amplitudes. section_name: Section name of the cell where the stimuli are applied. If None, the stimuli are applied at the soma[0] of the cell. segment: The segment of the section where the stimuli are applied. @@ -103,18 +106,37 @@ def apply_multiple_stimuli( # Prepare arguments for each stimulus for amplitude in amplitudes: + if threshold_based: + thres_perc = amplitude + amp = None + else: + thres_perc = None + amp = amplitude + if stimulus_name == StimulusName.AP_WAVEFORM: - stimulus = stim_factory.ap_waveform(threshold_current=cell.threshold, threshold_percentage=amplitude) + stimulus = stim_factory.ap_waveform( + threshold_current=cell.threshold, threshold_percentage=thres_perc, amplitude=amp + ) elif stimulus_name == StimulusName.IDREST: - stimulus = stim_factory.idrest(threshold_current=cell.threshold, threshold_percentage=amplitude) + stimulus = stim_factory.idrest( + threshold_current=cell.threshold, threshold_percentage=thres_perc, amplitude=amp + ) elif stimulus_name == StimulusName.IV: - stimulus = stim_factory.iv(threshold_current=cell.threshold, threshold_percentage=amplitude) + stimulus = stim_factory.iv( + threshold_current=cell.threshold, threshold_percentage=thres_perc, amplitude=amp + ) elif stimulus_name == StimulusName.FIRE_PATTERN: - stimulus = stim_factory.fire_pattern(threshold_current=cell.threshold, threshold_percentage=amplitude) + stimulus = stim_factory.fire_pattern( + threshold_current=cell.threshold, threshold_percentage=thres_perc, amplitude=amp + ) elif stimulus_name == StimulusName.POS_CHEOPS: - stimulus = stim_factory.pos_cheops(threshold_current=cell.threshold, threshold_percentage=amplitude) + stimulus = stim_factory.pos_cheops( + threshold_current=cell.threshold, threshold_percentage=thres_perc, amplitude=amp + ) elif stimulus_name == StimulusName.NEG_CHEOPS: - stimulus = stim_factory.neg_cheops(threshold_current=cell.threshold, threshold_percentage=amplitude) + stimulus = stim_factory.neg_cheops( + threshold_current=cell.threshold, threshold_percentage=thres_perc, amplitude=amp + ) else: raise ValueError("Unknown stimulus name.") diff --git a/bluecellulab/stimulus/factory.py b/bluecellulab/stimulus/factory.py index 1e666d94..5e8a1c95 100644 --- a/bluecellulab/stimulus/factory.py +++ b/bluecellulab/stimulus/factory.py @@ -1,8 +1,12 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Optional +import logging import matplotlib.pyplot as plt import numpy as np +logger = logging.getLogger(__name__) + class Stimulus(ABC): def __init__(self, dt: float) -> None: @@ -299,99 +303,182 @@ def ramp( ) def ap_waveform( - self, threshold_current: float, threshold_percentage: float = 220.0 + self, + threshold_current: Optional[float] = None, + threshold_percentage: Optional[float] = 220.0, + amplitude: Optional[float] = None, ) -> Stimulus: """Returns the APWaveform Stimulus object, a type of Step stimulus. Args: threshold_current: The threshold current of the Cell. threshold_percentage: Percentage of desired threshold_current amplification. + amplitude: Raw amplitude of input current. """ pre_delay = 250.0 duration = 50.0 post_delay = 250.0 - return Step.threshold_based( - self.dt, - pre_delay=pre_delay, - duration=duration, - post_delay=post_delay, - threshold_current=threshold_current, - threshold_percentage=threshold_percentage, - ) + + if amplitude is not None: + if threshold_current is not None and threshold_current != 0 and threshold_percentage is not None: + logger.info( + "amplitude, threshold_current and threshold_percentage are all set in ap_waveform." + " Will only keep amplitude value." + ) + return Step.amplitude_based( + self.dt, + pre_delay=pre_delay, + duration=duration, + post_delay=post_delay, + amplitude=amplitude, + ) + + if threshold_current is not None and threshold_current != 0 and threshold_percentage is not None: + return Step.threshold_based( + self.dt, + pre_delay=pre_delay, + duration=duration, + post_delay=post_delay, + threshold_current=threshold_current, + threshold_percentage=threshold_percentage, + ) + + raise TypeError("You have to give either threshold_current or amplitude") def idrest( self, - threshold_current: float, - threshold_percentage: float = 200.0, + threshold_current: Optional[float] = None, + threshold_percentage: Optional[float] = 200.0, + amplitude: Optional[float] = None, ) -> Stimulus: """Returns the IDRest Stimulus object, a type of Step stimulus. Args: threshold_current: The threshold current of the Cell. threshold_percentage: Percentage of desired threshold_current amplification. + amplitude: Raw amplitude of input current. """ pre_delay = 250.0 duration = 1350.0 post_delay = 250.0 - return Step.threshold_based( - self.dt, - pre_delay=pre_delay, - duration=duration, - post_delay=post_delay, - threshold_current=threshold_current, - threshold_percentage=threshold_percentage, - ) + + if amplitude is not None: + if threshold_current is not None and threshold_current != 0 and threshold_percentage is not None: + logger.info( + "amplitude, threshold_current and threshold_percentage are all set in idrest." + " Will only keep amplitude value." + ) + return Step.amplitude_based( + self.dt, + pre_delay=pre_delay, + duration=duration, + post_delay=post_delay, + amplitude=amplitude, + ) + + if threshold_current is not None and threshold_current != 0 and threshold_percentage is not None: + return Step.threshold_based( + self.dt, + pre_delay=pre_delay, + duration=duration, + post_delay=post_delay, + threshold_current=threshold_current, + threshold_percentage=threshold_percentage, + ) + + raise TypeError("You have to give either threshold_current or amplitude") def iv( self, - threshold_current: float, - threshold_percentage: float = -40.0, + threshold_current: Optional[float] = None, + threshold_percentage: Optional[float] = -40.0, + amplitude: Optional[float] = None, ) -> Stimulus: """Returns the IV Stimulus object, a type of Step stimulus. Args: threshold_current: The threshold current of the Cell. threshold_percentage: Percentage of desired threshold_current amplification. + amplitude: Raw amplitude of input current. """ pre_delay = 250.0 duration = 3000.0 post_delay = 250.0 - return Step.threshold_based( - self.dt, - pre_delay=pre_delay, - duration=duration, - post_delay=post_delay, - threshold_current=threshold_current, - threshold_percentage=threshold_percentage, - ) + + if amplitude is not None: + if threshold_current is not None and threshold_current != 0 and threshold_percentage is not None: + logger.info( + "amplitude, threshold_current and threshold_percentage are all set in iv." + " Will only keep amplitude value." + ) + return Step.amplitude_based( + self.dt, + pre_delay=pre_delay, + duration=duration, + post_delay=post_delay, + amplitude=amplitude, + ) + + if threshold_current is not None and threshold_current != 0 and threshold_percentage is not None: + return Step.threshold_based( + self.dt, + pre_delay=pre_delay, + duration=duration, + post_delay=post_delay, + threshold_current=threshold_current, + threshold_percentage=threshold_percentage, + ) + + raise TypeError("You have to give either threshold_current or amplitude") def fire_pattern( self, - threshold_current: float, - threshold_percentage: float = 200.0, + threshold_current: Optional[float] = None, + threshold_percentage: Optional[float] = 200.0, + amplitude: Optional[float] = None, ) -> Stimulus: """Returns the FirePattern Stimulus object, a type of Step stimulus. Args: threshold_current: The threshold current of the Cell. threshold_percentage: Percentage of desired threshold_current amplification. + amplitude: Raw amplitude of input current. """ pre_delay = 250.0 duration = 3600.0 post_delay = 250.0 - return Step.threshold_based( - self.dt, - pre_delay=pre_delay, - duration=duration, - post_delay=post_delay, - threshold_current=threshold_current, - threshold_percentage=threshold_percentage, - ) + + if amplitude is not None: + if threshold_current is not None and threshold_current != 0 and threshold_percentage is not None: + logger.info( + "amplitude, threshold_current and threshold_percentage are all set in fire_pattern." + " Will only keep amplitude value." + ) + return Step.amplitude_based( + self.dt, + pre_delay=pre_delay, + duration=duration, + post_delay=post_delay, + amplitude=amplitude, + ) + + if threshold_current is not None and threshold_current != 0 and threshold_percentage is not None: + return Step.threshold_based( + self.dt, + pre_delay=pre_delay, + duration=duration, + post_delay=post_delay, + threshold_current=threshold_current, + threshold_percentage=threshold_percentage, + ) + + raise TypeError("You have to give either threshold_current or amplitude") def pos_cheops( self, - threshold_current: float, - threshold_percentage: float = 300.0, + threshold_current: Optional[float] = None, + threshold_percentage: Optional[float] = 300.0, + amplitude: Optional[float] = None, ) -> Stimulus: """A combination of pyramid shaped Ramp stimuli with a positive amplitude. @@ -399,6 +486,7 @@ def pos_cheops( Args: threshold_current: The threshold current of the Cell. threshold_percentage: Percentage of desired threshold_current amplification. + amplitude: Raw amplitude of input current. """ delay = 250.0 ramp1_duration = 4000.0 @@ -407,7 +495,15 @@ def pos_cheops( inter_delay = 2000.0 post_delay = 250.0 - amplitude = threshold_current * threshold_percentage / 100 + if amplitude is None: + if threshold_current is None or threshold_current == 0 or threshold_percentage is None: + raise TypeError("You have to give either threshold_current or amplitude") + amplitude = threshold_current * threshold_percentage / 100 + elif threshold_current is not None and threshold_current != 0 and threshold_percentage is not None: + logger.info( + "amplitude, threshold_current and threshold_percentage are all set in pos_cheops." + " Will only keep amplitude value." + ) result = ( Empty(self.dt, duration=delay) + Slope(self.dt, duration=ramp1_duration, amplitude_start=0.0, amplitude_end=amplitude) @@ -424,8 +520,9 @@ def pos_cheops( def neg_cheops( self, - threshold_current: float, - threshold_percentage: float = 300.0, + threshold_current: Optional[float] = None, + threshold_percentage: Optional[float] = 300.0, + amplitude: Optional[float] = None, ) -> Stimulus: """A combination of pyramid shaped Ramp stimuli with a negative amplitude. @@ -433,6 +530,7 @@ def neg_cheops( Args: threshold_current: The threshold current of the Cell. threshold_percentage: Percentage of desired threshold_current amplification. + amplitude: Raw amplitude of input current. """ delay = 1750.0 ramp1_duration = 3333.0 @@ -441,7 +539,15 @@ def neg_cheops( inter_delay = 2000.0 post_delay = 250.0 - amplitude = - threshold_current * threshold_percentage / 100 + if amplitude is None: + if threshold_current is None or threshold_current == 0 or threshold_percentage is None: + raise TypeError("You have to give either threshold_current or amplitude") + amplitude = - threshold_current * threshold_percentage / 100 + elif threshold_current is not None and threshold_current != 0 and threshold_percentage is not None: + logger.info( + "amplitude, threshold_current and threshold_percentage are all set in neg_cheops." + " Will only keep amplitude value." + ) result = ( Empty(self.dt, duration=delay) + Slope(self.dt, duration=ramp1_duration, amplitude_start=0.0, amplitude_end=amplitude) diff --git a/tests/test_analysis/test_inject_sequence.py b/tests/test_analysis/test_inject_sequence.py index 30267677..8b130171 100644 --- a/tests/test_analysis/test_inject_sequence.py +++ b/tests/test_analysis/test_inject_sequence.py @@ -32,6 +32,7 @@ def mock_run_stimulus(): def test_apply_multiple_step_stimuli(mock_run_stimulus): """Do not run the code in parallel, mock the return value via MockRecording.""" amplitudes = [80, 100, 120, 140] + thres_perc = [0.08] cell = create_ball_stick() with patch('bluecellulab.analysis.inject_sequence.IsolatedProcess') as mock_isolated_process, \ @@ -39,9 +40,11 @@ def test_apply_multiple_step_stimuli(mock_run_stimulus): # the mock process pool to return a list of MockRecordings mock_isolated_process.return_value.__enter__.return_value.starmap.return_value = [MockRecording() for _ in amplitudes] - recordings = apply_multiple_stimuli(cell, StimulusName.FIRE_PATTERN, amplitudes, n_processes=4) + recordings = apply_multiple_stimuli(cell, StimulusName.FIRE_PATTERN, amplitudes, threshold_based=False, n_processes=4) + recordings_thres = apply_multiple_stimuli(cell, StimulusName.FIRE_PATTERN, thres_perc, n_processes=4) assert len(recordings) == len(amplitudes) - for recording in recordings.values(): + assert len(recordings_thres) == len(thres_perc) + for recording in list(recordings.values()) + list(recordings_thres.values()): assert len(recording.time) > 0 assert len(recording.time) == len(recording.voltage) assert len(recording.time) == len(recording.current) @@ -52,7 +55,10 @@ def test_apply_multiple_step_stimuli(mock_run_stimulus): assert "Unknown stimulus name" in str(exc_info.value) short_amplitudes = [80] + short_thres = [0.08] other_stim = [StimulusName.AP_WAVEFORM, StimulusName.IV, StimulusName.IDREST, StimulusName.POS_CHEOPS, StimulusName.NEG_CHEOPS] for stim in other_stim: - res = apply_multiple_stimuli(cell, stim, short_amplitudes, n_processes=1) + res = apply_multiple_stimuli(cell, stim, short_amplitudes, threshold_based=False, n_processes=1) + res_thres = apply_multiple_stimuli(cell, stim, short_thres, n_processes=1) assert len(res) == len(short_amplitudes) + assert len(res_thres) == len(short_thres) diff --git a/tests/test_stimulus/test_factory.py b/tests/test_stimulus/test_factory.py index 6ef6f23c..9e26db63 100644 --- a/tests/test_stimulus/test_factory.py +++ b/tests/test_stimulus/test_factory.py @@ -91,11 +91,28 @@ def test_create_ap_waveform(self): assert s.current[2500] == 2.2 assert s.current[-1] == 0.0 + s = self.factory.ap_waveform(amplitude=0.1) + assert isinstance(s, CombinedStimulus) + + with pytest.raises(TypeError, match="You have to give either threshold_current or amplitude"): + self.factory.ap_waveform(threshold_current=0.0) + with pytest.raises(TypeError, match="You have to give either threshold_current or amplitude"): + self.factory.ap_waveform() + def test_create_idrest(self): s = self.factory.idrest(threshold_current=1) assert isinstance(s, CombinedStimulus) assert s.stimulus_time == 1850 + s = self.factory.idrest(amplitude=0.1) + assert isinstance(s, CombinedStimulus) + assert s.stimulus_time == 1850 + + with pytest.raises(TypeError, match="You have to give either threshold_current or amplitude"): + self.factory.idrest(threshold_current=0.0) + with pytest.raises(TypeError, match="You have to give either threshold_current or amplitude"): + self.factory.idrest() + def test_create_iv(self): s = self.factory.iv(threshold_current=1) assert isinstance(s, CombinedStimulus) @@ -105,21 +122,57 @@ def test_create_iv(self): # assert no positive values assert not np.any(s.current > 0) + s = self.factory.iv(amplitude=0.1) + assert isinstance(s, CombinedStimulus) + assert s.stimulus_time == 3500 + + with pytest.raises(TypeError, match="You have to give either threshold_current or amplitude"): + self.factory.iv(threshold_current=0.0) + with pytest.raises(TypeError, match="You have to give either threshold_current or amplitude"): + self.factory.iv() + def test_create_fire_pattern(self): s = self.factory.fire_pattern(threshold_current=1) assert isinstance(s, CombinedStimulus) assert s.stimulus_time == 4100 + s = self.factory.fire_pattern(amplitude=0.1) + assert isinstance(s, CombinedStimulus) + assert s.stimulus_time == 4100 + + with pytest.raises(TypeError, match="You have to give either threshold_current or amplitude"): + self.factory.fire_pattern(threshold_current=0.0) + with pytest.raises(TypeError, match="You have to give either threshold_current or amplitude"): + self.factory.fire_pattern() + def test_create_pos_cheops(self): s = self.factory.pos_cheops(threshold_current=1) assert isinstance(s, CombinedStimulus) assert s.stimulus_time == 19166.0 + s = self.factory.pos_cheops(amplitude=0.1) + assert isinstance(s, CombinedStimulus) + assert s.stimulus_time == 19166.0 + + with pytest.raises(TypeError, match="You have to give either threshold_current or amplitude"): + self.factory.pos_cheops(threshold_current=0.0) + with pytest.raises(TypeError, match="You have to give either threshold_current or amplitude"): + self.factory.pos_cheops() + def test_create_neg_cheops(self): s = self.factory.neg_cheops(threshold_current=1) assert isinstance(s, CombinedStimulus) assert s.stimulus_time == 18220.0 + s = self.factory.neg_cheops(amplitude=0.1) + assert isinstance(s, CombinedStimulus) + assert s.stimulus_time == 18220.0 + + with pytest.raises(TypeError, match="You have to give either threshold_current or amplitude"): + self.factory.neg_cheops(threshold_current=0.0) + with pytest.raises(TypeError, match="You have to give either threshold_current or amplitude"): + self.factory.neg_cheops() + def test_combined_stimulus(): """Test combining Stimulus objects."""