From 27e88044485d4f9ce7dcddfa0eb3a00155820dfd Mon Sep 17 00:00:00 2001 From: ilkilic Date: Tue, 27 Aug 2024 17:28:46 +0200 Subject: [PATCH 1/5] fix protocol name in compute_scores when using threshold based optimisation --- bluepyemodel/validation/validation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bluepyemodel/validation/validation.py b/bluepyemodel/validation/validation.py index 28cf8286..05859311 100644 --- a/bluepyemodel/validation/validation.py +++ b/bluepyemodel/validation/validation.py @@ -76,7 +76,11 @@ def compute_scores(model, validation_protocols): scores = model.evaluator.fitness_calculator.calculate_scores(model.responses) for feature_name in scores: - protocol_name = feature_name.split(".")[0] + n = feature_name.split(".") + # case where protocol has '.' in its name, e.g. IV_40.0 + if n[1].isdigit(): + n = [".".join(n[:2]), n[2], n[3]] + protocol_name = n[0] if any(are_same_protocol(p, protocol_name) for p in validation_protocols): model.scores_validation[feature_name] = scores[feature_name] else: From a4f74cd189087cea99dc512a1006555ef7162874 Mon Sep 17 00:00:00 2001 From: ilkilic Date: Wed, 28 Aug 2024 10:00:02 +0200 Subject: [PATCH 2/5] minor fix --- bluepyemodel/validation/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bluepyemodel/validation/validation.py b/bluepyemodel/validation/validation.py index 05859311..6d2af34b 100644 --- a/bluepyemodel/validation/validation.py +++ b/bluepyemodel/validation/validation.py @@ -79,7 +79,7 @@ def compute_scores(model, validation_protocols): n = feature_name.split(".") # case where protocol has '.' in its name, e.g. IV_40.0 if n[1].isdigit(): - n = [".".join(n[:2]), n[2], n[3]] + n = [".".join(n[:2])] + n[2:] protocol_name = n[0] if any(are_same_protocol(p, protocol_name) for p in validation_protocols): model.scores_validation[feature_name] = scores[feature_name] From d8de5677a2003cbf139484b0a5da0602fafe6b7e Mon Sep 17 00:00:00 2001 From: ilkilic Date: Wed, 28 Aug 2024 14:12:41 +0200 Subject: [PATCH 3/5] refactoring --- .../emodel_pipeline/plotting_utils.py | 22 +++----- bluepyemodel/tools/utils.py | 56 +++++++++++++++++-- bluepyemodel/validation/validation.py | 7 +-- tests/unit_tests/test_tools.py | 38 +++++++++++++ 4 files changed, 98 insertions(+), 25 deletions(-) diff --git a/bluepyemodel/emodel_pipeline/plotting_utils.py b/bluepyemodel/emodel_pipeline/plotting_utils.py index dc473c41..e213cda1 100644 --- a/bluepyemodel/emodel_pipeline/plotting_utils.py +++ b/bluepyemodel/emodel_pipeline/plotting_utils.py @@ -32,6 +32,9 @@ from bluepyemodel.evaluation.protocols import ThresholdBasedProtocol from bluepyemodel.evaluation.recordings import FixedDtRecordingCustom from bluepyemodel.evaluation.recordings import FixedDtRecordingStimulus +from bluepyemodel.tools.utils import get_curr_name +from bluepyemodel.tools.utils import get_loc_name +from bluepyemodel.tools.utils import get_protocol_name logger = logging.getLogger("__main__") @@ -379,12 +382,7 @@ def get_simulated_FI_curve_for_plotting(evaluator, responses, prot_name): simulated_amp = [] for val in values: if prot_name.lower() in val.lower(): - # val is e.g. IV_40.soma.maximum_voltage_from_voltagebase - n = val.split(".") - # case where protocol has '.' in its name, e.g. IV_40.0 - if len(n) == 4 and n[1].isdigit(): - n = [".".join(n[:2]), n[2], n[3]] - protocol_name = n[0] + protocol_name = get_protocol_name(val) amp_temp = float(protocol_name.split("_")[-1]) if "mean_frequency" in val: simulated_freq.append(values[val]) @@ -593,17 +591,11 @@ def get_ordered_currentscape_keys(keys): ordered_keys = {} for name in keys: - n = name.split(".") - # case where protocol has '.' in its name, e.g. IV_-100.0 - if len(n) == 4 and n[1].isdigit(): - n = [".".join(n[:2]), n[2], n[3]] - prot_name = n[0] + prot_name = get_protocol_name(name) # prot_name can be e.g. RMPProtocol, or RMPProtocol_apical055 if not any(to_skip_ in prot_name for to_skip_ in to_skip): - if len(n) != 3: - raise ValueError(f"Expected 3 elements in {n}") - loc_name = n[1] - curr_name = n[2] + loc_name = get_loc_name(name) + curr_name = get_curr_name(name) if prot_name not in ordered_keys: ordered_keys[prot_name] = {} diff --git a/bluepyemodel/tools/utils.py b/bluepyemodel/tools/utils.py index bef41fc7..bc962819 100644 --- a/bluepyemodel/tools/utils.py +++ b/bluepyemodel/tools/utils.py @@ -256,10 +256,56 @@ def get_amplitude_from_feature_key(feat_key): Args: feat_key (str): feature key, e.g. IV_40.soma.maximum_voltage_from_voltagebase """ - n = feat_key.split(".") - # case where protocol has '.' in its name, e.g. IV_40.0 - if len(n) == 4 and n[1].isdigit(): - n = [".".join(n[:2]), n[2], n[3]] - protocol_name = n[0] + protocol_name = get_protocol_name(feat_key) return float(protocol_name.split("_")[-1]) + + +def combine_parts_if_dot_in_protocol(feature_name): + """ + Combine the first two elements of a list if the second element is numeric, + indicating the presence of a dot in the protocol. + + Args: + feature_name (list): The list of split parts from the feature name. + """ + if len(feature_name) > 1 and feature_name[1].isdigit(): + return [".".join(feature_name[:2])] + feature_name[2:] + return feature_name + + +def get_protocol_name(feature_name): + """ + Extract the protocol name from the feature name. + + Args: + feature_name (str): The full feature name string. + """ + n = combine_parts_if_dot_in_protocol(feature_name.split(".")) + return n[0] + + +def get_loc_name(feature_name): + """ + Extract the location name from the feature name. + + Args: + feature_name (str): The full feature name string. + """ + n = combine_parts_if_dot_in_protocol(feature_name.split(".")) + if len(n) < 2: + raise IndexError("cannot get location name from feature name") + return n[1] + + +def get_curr_name(feature_name): + """ + Extract the current name from the feature name. + + Args: + feature_name (str): The full feature name string. + """ + n = combine_parts_if_dot_in_protocol(feature_name.split(".")) + if len(n) < 3: + raise IndexError("cannot get current name from feature name") + return n[2] diff --git a/bluepyemodel/validation/validation.py b/bluepyemodel/validation/validation.py index 6d2af34b..0d217c45 100644 --- a/bluepyemodel/validation/validation.py +++ b/bluepyemodel/validation/validation.py @@ -25,6 +25,7 @@ from bluepyemodel.evaluation.evaluation import compute_responses from bluepyemodel.evaluation.evaluation import get_evaluator_from_access_point from bluepyemodel.tools.utils import are_same_protocol +from bluepyemodel.tools.utils import get_protocol_name from bluepyemodel.validation import validation_functions logger = logging.getLogger(__name__) @@ -76,11 +77,7 @@ def compute_scores(model, validation_protocols): scores = model.evaluator.fitness_calculator.calculate_scores(model.responses) for feature_name in scores: - n = feature_name.split(".") - # case where protocol has '.' in its name, e.g. IV_40.0 - if n[1].isdigit(): - n = [".".join(n[:2])] + n[2:] - protocol_name = n[0] + protocol_name = get_protocol_name(feature_name) if any(are_same_protocol(p, protocol_name) for p in validation_protocols): model.scores_validation[feature_name] = scores[feature_name] else: diff --git a/tests/unit_tests/test_tools.py b/tests/unit_tests/test_tools.py index 85e1cb84..f565a270 100644 --- a/tests/unit_tests/test_tools.py +++ b/tests/unit_tests/test_tools.py @@ -20,6 +20,9 @@ from bluepyemodel.tools.utils import are_same_protocol from bluepyemodel.tools.utils import format_protocol_name_to_list from bluepyemodel.tools.utils import select_rec_for_thumbnail +from bluepyemodel.tools.utils import get_protocol_name +from bluepyemodel.tools.utils import get_loc_name +from bluepyemodel.tools.utils import get_curr_name from tests.utils import DATA @@ -136,3 +139,38 @@ def test_select_rec_for_thumbnail(): assert ( select_rec_for_thumbnail(rec_names, thumbnail_rec="sAHP_20.soma.v") == "IDrest_130.soma.v" ) + + +def test_get_protocol_name(): + feature_name = "IV_40.0.soma.v.voltage_base" + assert get_protocol_name(feature_name) == "IV_40.0" + + feature_name = "IV_40.soma.v.voltage_base" + assert get_protocol_name(feature_name) == "IV_40" + + feature_name = "ProtocolA.1.soma.some_feature" + assert get_protocol_name(feature_name) == "ProtocolA.1" + + +def test_get_loc_name(): + feature_name = "IV_40.0.soma.v.voltage_base" + assert get_loc_name(feature_name) == "soma" + + feature_name = "IV_40.soma.v.voltage_base" + assert get_loc_name(feature_name) == "soma" + + feature_name = "IV_40.0" + with pytest.raises(IndexError, match="cannot get location name from feature name"): + get_loc_name(feature_name) + + +def test_get_curr_name(): + feature_name = "IV_40.0.soma.v.voltage_base" + assert get_curr_name(feature_name) == "v" + + feature_name = "IV_40.soma.v.voltage_base" + assert get_curr_name(feature_name) == "v" + + feature_name = "IV_40.0.soma" + with pytest.raises(IndexError, match="cannot get current name from feature name"): + get_curr_name(feature_name) \ No newline at end of file From ce08dca6ba30b727d9b838e44aeda59b8bc8fe85 Mon Sep 17 00:00:00 2001 From: ilkilic Date: Wed, 28 Aug 2024 16:06:44 +0200 Subject: [PATCH 4/5] minor updates --- bluepyemodel/tools/utils.py | 108 ++++++++++++++++++++++++++------- tests/unit_tests/test_tools.py | 40 ++++++++++-- 2 files changed, 121 insertions(+), 27 deletions(-) diff --git a/bluepyemodel/tools/utils.py b/bluepyemodel/tools/utils.py index bc962819..771929fe 100644 --- a/bluepyemodel/tools/utils.py +++ b/bluepyemodel/tools/utils.py @@ -261,51 +261,113 @@ def get_amplitude_from_feature_key(feat_key): return float(protocol_name.split("_")[-1]) -def combine_parts_if_dot_in_protocol(feature_name): +def parse_feature_name_parts(feature_name): """ - Combine the first two elements of a list if the second element is numeric, - indicating the presence of a dot in the protocol. + Splits the feature name into its respective parts, handling cases where the protocol name contains a dot. + + This function works with both a full feature name string (e.g., "IV_40.0.soma.v.voltage_base") + and a response key (e.g., "IV_40.0.soma.v"). It splits the input into a list of parts, combining the + first two parts if the protocol name contains a dot and is followed by a numeric component. Args: - feature_name (list): The list of split parts from the feature name. + feature_name (str): The full feature name string or response key to be parsed. + + Returns: + list: A list of strings representing the correctly parsed parts of the feature name. + + Examples: + >>> parse_feature_name_parts("IV_40.0.soma.v.voltage_base") + ['IV_40.0', 'soma', 'v', 'voltage_base'] + + >>> parse_feature_name_parts("IV_40.0.soma.v") + ['IV_40.0', 'soma', 'v'] """ - if len(feature_name) > 1 and feature_name[1].isdigit(): - return [".".join(feature_name[:2])] + feature_name[2:] - return feature_name + parts = feature_name.split(".") + if len(parts) > 1 and parts[1].isdigit(): + return [".".join(parts[:2])] + parts[2:] + return parts def get_protocol_name(feature_name): """ - Extract the protocol name from the feature name. + Extracts the protocol name from the feature name or response key. + + This function works with both a full feature name string (e.g., "IV_40.0.soma.v.voltage_base") + and a response key (e.g., "IV_40.0.soma.v"). It returns the first part of the input, which is + the protocol name, correctly handling cases where the protocol contains a dot. Args: - feature_name (str): The full feature name string. + feature_name (str): The full feature name string or response key. + + Returns: + str: The protocol name part of the feature name. + + Examples: + >>> get_protocol_name("IV_40.0.soma.v.voltage_base") + 'IV_40.0' + + >>> get_protocol_name("IV_40.0.soma.v") + 'IV_40.0' """ - n = combine_parts_if_dot_in_protocol(feature_name.split(".")) - return n[0] + return parse_feature_name_parts(feature_name)[0] def get_loc_name(feature_name): """ - Extract the location name from the feature name. + Extracts the location name from the feature name or response key. + + This function works with both a full feature name string (e.g., "IV_40.0.soma.v.voltage_base") + and a response key (e.g., "IV_40.0.soma.v"). It returns the second part of the input, which is + the location name, correctly handling cases where the protocol contains a dot. Args: - feature_name (str): The full feature name string. + feature_name (str): The full feature name string or response key. + + Returns: + str: The location name part of the feature name. + + Raises: + IndexError: If the location name cannot be determined from the input. + + Examples: + >>> get_loc_name("IV_40.0.soma.v.voltage_base") + 'soma' + + >>> get_loc_name("IV_40.0.soma.v") + 'soma' """ - n = combine_parts_if_dot_in_protocol(feature_name.split(".")) - if len(n) < 2: - raise IndexError("cannot get location name from feature name") - return n[1] + parts = parse_feature_name_parts(feature_name) + if len(parts) < 2: + raise IndexError("Location name not found in the feature name.") + return parts[1] def get_curr_name(feature_name): """ - Extract the current name from the feature name. + Extracts the current name from the feature name or response key. + + This function works with both a full feature name string (e.g., "IV_40.0.soma.v.voltage_base") + and a response key (e.g., "IV_40.0.soma.v"). It returns the third part of the input, which is + the current name, correctly handling cases where the protocol contains a dot. Args: - feature_name (str): The full feature name string. + feature_name (str): The full feature name string or response key. + + Returns: + str: The current name part of the feature name. + + Raises: + IndexError: If the current name cannot be determined from the input. + + Examples: + >>> get_curr_name("IV_40.0.soma.v.voltage_base") + 'v' + + >>> get_curr_name("IV_40.0.soma.v") + 'v' """ - n = combine_parts_if_dot_in_protocol(feature_name.split(".")) - if len(n) < 3: - raise IndexError("cannot get current name from feature name") - return n[2] + parts = parse_feature_name_parts(feature_name) + if len(parts) < 3: + raise IndexError("Current name not found in the feature name.") + return parts[2] + diff --git a/tests/unit_tests/test_tools.py b/tests/unit_tests/test_tools.py index f565a270..7c14dab9 100644 --- a/tests/unit_tests/test_tools.py +++ b/tests/unit_tests/test_tools.py @@ -142,17 +142,29 @@ def test_select_rec_for_thumbnail(): def test_get_protocol_name(): + # feature keys feature_name = "IV_40.0.soma.v.voltage_base" assert get_protocol_name(feature_name) == "IV_40.0" feature_name = "IV_40.soma.v.voltage_base" assert get_protocol_name(feature_name) == "IV_40" - feature_name = "ProtocolA.1.soma.some_feature" + feature_name = "ProtocolA.1.soma.v.some_feature" + assert get_protocol_name(feature_name) == "ProtocolA.1" + + # response keys + feature_name = "IV_40.0.soma.v" + assert get_protocol_name(feature_name) == "IV_40.0" + + feature_name = "IV_40.soma.v" + assert get_protocol_name(feature_name) == "IV_40" + + feature_name = "ProtocolA.1.soma.v" assert get_protocol_name(feature_name) == "ProtocolA.1" def test_get_loc_name(): + # feature keys feature_name = "IV_40.0.soma.v.voltage_base" assert get_loc_name(feature_name) == "soma" @@ -160,11 +172,21 @@ def test_get_loc_name(): assert get_loc_name(feature_name) == "soma" feature_name = "IV_40.0" - with pytest.raises(IndexError, match="cannot get location name from feature name"): + with pytest.raises(IndexError, match="Location name not found in the feature name."): get_loc_name(feature_name) + # response keys + feature_name = "IV_40.0.soma.v" + assert get_loc_name(feature_name) == "soma" + + feature_name = "IV_40.soma.v" + assert get_loc_name(feature_name) == "soma" + + feature_name = "ProtocolA.1.soma.v" + assert get_loc_name(feature_name) == "soma" def test_get_curr_name(): + # feature keys feature_name = "IV_40.0.soma.v.voltage_base" assert get_curr_name(feature_name) == "v" @@ -172,5 +194,15 @@ def test_get_curr_name(): assert get_curr_name(feature_name) == "v" feature_name = "IV_40.0.soma" - with pytest.raises(IndexError, match="cannot get current name from feature name"): - get_curr_name(feature_name) \ No newline at end of file + with pytest.raises(IndexError, match="Current name not found in the feature name."): + get_curr_name(feature_name) + + # response keys + feature_name = "IV_40.0.soma.v" + assert get_curr_name(feature_name) == "v" + + feature_name = "IV_40.soma.v" + assert get_curr_name(feature_name) == "v" + + feature_name = "ProtocolA.1.soma.v" + assert get_curr_name(feature_name) == "v" \ No newline at end of file From e52d1f3a405c3308ad5acb16549c564cb079cabc Mon Sep 17 00:00:00 2001 From: ilkilic Date: Wed, 28 Aug 2024 16:14:42 +0200 Subject: [PATCH 5/5] lint fix --- bluepyemodel/tools/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/bluepyemodel/tools/utils.py b/bluepyemodel/tools/utils.py index 771929fe..73610b18 100644 --- a/bluepyemodel/tools/utils.py +++ b/bluepyemodel/tools/utils.py @@ -263,11 +263,13 @@ def get_amplitude_from_feature_key(feat_key): def parse_feature_name_parts(feature_name): """ - Splits the feature name into its respective parts, handling cases where the protocol name contains a dot. + Splits the feature name into its respective parts, + handling cases where the protocol name contains a dot. This function works with both a full feature name string (e.g., "IV_40.0.soma.v.voltage_base") - and a response key (e.g., "IV_40.0.soma.v"). It splits the input into a list of parts, combining the - first two parts if the protocol name contains a dot and is followed by a numeric component. + and a response key (e.g., "IV_40.0.soma.v"). It splits the input into a list of parts, + combining the first two parts if the protocol name contains a dot and is followed by + a numeric component. Args: feature_name (str): The full feature name string or response key to be parsed. @@ -370,4 +372,3 @@ def get_curr_name(feature_name): if len(parts) < 3: raise IndexError("Current name not found in the feature name.") return parts[2] -