Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix compute_scores to handle protocol names with '.' #166

Merged
merged 5 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 7 additions & 15 deletions bluepyemodel/emodel_pipeline/plotting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from bluepyemodel.evaluation.protocols import ThresholdBasedProtocol
from bluepyemodel.evaluation.recordings import FixedDtRecordingCustom
from bluepyemodel.evaluation.recordings import FixedDtRecordingStimulus
from bluepyemodel.tools.utils import get_curr_name
from bluepyemodel.tools.utils import get_loc_name
from bluepyemodel.tools.utils import get_protocol_name

logger = logging.getLogger("__main__")

Expand Down Expand Up @@ -379,12 +382,7 @@ def get_simulated_FI_curve_for_plotting(evaluator, responses, prot_name):
simulated_amp = []
for val in values:
if prot_name.lower() in val.lower():
# val is e.g. IV_40.soma.maximum_voltage_from_voltagebase
n = val.split(".")
# case where protocol has '.' in its name, e.g. IV_40.0
if len(n) == 4 and n[1].isdigit():
n = [".".join(n[:2]), n[2], n[3]]
protocol_name = n[0]
protocol_name = get_protocol_name(val)
amp_temp = float(protocol_name.split("_")[-1])
if "mean_frequency" in val:
simulated_freq.append(values[val])
Expand Down Expand Up @@ -593,17 +591,11 @@ def get_ordered_currentscape_keys(keys):

ordered_keys = {}
for name in keys:
n = name.split(".")
# case where protocol has '.' in its name, e.g. IV_-100.0
if len(n) == 4 and n[1].isdigit():
n = [".".join(n[:2]), n[2], n[3]]
prot_name = n[0]
prot_name = get_protocol_name(name)
# prot_name can be e.g. RMPProtocol, or RMPProtocol_apical055
if not any(to_skip_ in prot_name for to_skip_ in to_skip):
if len(n) != 3:
raise ValueError(f"Expected 3 elements in {n}")
loc_name = n[1]
curr_name = n[2]
loc_name = get_loc_name(name)
curr_name = get_curr_name(name)

if prot_name not in ordered_keys:
ordered_keys[prot_name] = {}
Expand Down
119 changes: 114 additions & 5 deletions bluepyemodel/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,119 @@ def get_amplitude_from_feature_key(feat_key):
Args:
feat_key (str): feature key, e.g. IV_40.soma.maximum_voltage_from_voltagebase
"""
n = feat_key.split(".")
# case where protocol has '.' in its name, e.g. IV_40.0
if len(n) == 4 and n[1].isdigit():
n = [".".join(n[:2]), n[2], n[3]]
protocol_name = n[0]
protocol_name = get_protocol_name(feat_key)

return float(protocol_name.split("_")[-1])


def parse_feature_name_parts(feature_name):
    """Break a dotted feature name or response key into its components.

    The input is split on ".". When the second piece consists only of
    digits, it is treated as the fractional part of a protocol name that
    itself contains a dot (e.g. "IV_40.0") and is glued back onto the
    first piece. All remaining pieces are returned unchanged.

    Both full feature names (e.g. "IV_40.0.soma.v.voltage_base") and
    response keys (e.g. "IV_40.0.soma.v") are accepted.

    Args:
        feature_name (str): feature name or response key to split.

    Returns:
        list: the components, with the protocol name as first element.

    Examples:
        >>> parse_feature_name_parts("IV_40.0.soma.v.voltage_base")
        ['IV_40.0', 'soma', 'v', 'voltage_base']

        >>> parse_feature_name_parts("IV_40.soma.v")
        ['IV_40', 'soma', 'v']
    """
    head, *tail = feature_name.split(".")
    if tail and tail[0].isdigit():
        # the first dot was part of the protocol name, e.g. IV_40.0
        return [f"{head}.{tail[0]}", *tail[1:]]
    return [head, *tail]


def get_protocol_name(feature_name):
    """Return the protocol part of a feature name or response key.

    The name is split with ``parse_feature_name_parts`` and the first
    component is returned, so a protocol whose name contains a dot
    (e.g. "IV_40.0") is returned whole.

    Args:
        feature_name (str): full feature name
            (e.g. "IV_40.0.soma.v.voltage_base") or response key
            (e.g. "IV_40.0.soma.v").

    Returns:
        str: the protocol name.

    Examples:
        >>> get_protocol_name("IV_40.0.soma.v.voltage_base")
        'IV_40.0'

        >>> get_protocol_name("IV_40.soma.v")
        'IV_40'
    """
    name_parts = parse_feature_name_parts(feature_name)
    return name_parts[0]


def get_loc_name(feature_name):
    """Return the location part of a feature name or response key.

    The name is split with ``parse_feature_name_parts`` and the second
    component is returned, so a dot inside the protocol name
    (e.g. "IV_40.0") does not shift the location.

    Args:
        feature_name (str): full feature name
            (e.g. "IV_40.0.soma.v.voltage_base") or response key
            (e.g. "IV_40.0.soma.v").

    Returns:
        str: the location name.

    Raises:
        IndexError: if the parsed name has fewer than two components.

    Examples:
        >>> get_loc_name("IV_40.0.soma.v.voltage_base")
        'soma'

        >>> get_loc_name("IV_40.soma.v")
        'soma'
    """
    name_parts = parse_feature_name_parts(feature_name)
    if len(name_parts) >= 2:
        return name_parts[1]
    raise IndexError("Location name not found in the feature name.")


def get_curr_name(feature_name):
    """Return the current part of a feature name or response key.

    The name is split with ``parse_feature_name_parts`` and the third
    component is returned, so a dot inside the protocol name
    (e.g. "IV_40.0") does not shift the current name.

    Args:
        feature_name (str): full feature name
            (e.g. "IV_40.0.soma.v.voltage_base") or response key
            (e.g. "IV_40.0.soma.v").

    Returns:
        str: the current name.

    Raises:
        IndexError: if the parsed name has fewer than three components.

    Examples:
        >>> get_curr_name("IV_40.0.soma.v.voltage_base")
        'v'

        >>> get_curr_name("IV_40.soma.v")
        'v'
    """
    name_parts = parse_feature_name_parts(feature_name)
    if len(name_parts) >= 3:
        return name_parts[2]
    raise IndexError("Current name not found in the feature name.")
3 changes: 2 additions & 1 deletion bluepyemodel/validation/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from bluepyemodel.evaluation.evaluation import compute_responses
from bluepyemodel.evaluation.evaluation import get_evaluator_from_access_point
from bluepyemodel.tools.utils import are_same_protocol
from bluepyemodel.tools.utils import get_protocol_name
from bluepyemodel.validation import validation_functions

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,7 +77,7 @@ def compute_scores(model, validation_protocols):

scores = model.evaluator.fitness_calculator.calculate_scores(model.responses)
for feature_name in scores:
protocol_name = feature_name.split(".")[0]
protocol_name = get_protocol_name(feature_name)
if any(are_same_protocol(p, protocol_name) for p in validation_protocols):
model.scores_validation[feature_name] = scores[feature_name]
else:
Expand Down
70 changes: 70 additions & 0 deletions tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from bluepyemodel.tools.utils import are_same_protocol
from bluepyemodel.tools.utils import format_protocol_name_to_list
from bluepyemodel.tools.utils import select_rec_for_thumbnail
from bluepyemodel.tools.utils import get_protocol_name
from bluepyemodel.tools.utils import get_loc_name
from bluepyemodel.tools.utils import get_curr_name
from tests.utils import DATA


Expand Down Expand Up @@ -136,3 +139,70 @@ def test_select_rec_for_thumbnail():
assert (
select_rec_for_thumbnail(rec_names, thumbnail_rec="sAHP_20.soma.v") == "IDrest_130.soma.v"
)


def test_get_protocol_name():
    """Check get_protocol_name on protocol names with and without a dot.

    Reviewer request recorded with this test: also try strings from
    response keys (i.e. without the feature at the end,
    e.g. 'IV_40.0.soma.v'), not only full feature keys.
    """
    expected_by_key = {
        # feature keys
        "IV_40.0.soma.v.voltage_base": "IV_40.0",
        "IV_40.soma.v.voltage_base": "IV_40",
        "ProtocolA.1.soma.v.some_feature": "ProtocolA.1",
        # response keys
        "IV_40.0.soma.v": "IV_40.0",
        "IV_40.soma.v": "IV_40",
        "ProtocolA.1.soma.v": "ProtocolA.1",
    }
    for key, expected in expected_by_key.items():
        assert get_protocol_name(key) == expected


def test_get_loc_name():
    """Check get_loc_name on feature keys, response keys and a too-short name."""
    feature_keys = (
        "IV_40.0.soma.v.voltage_base",
        "IV_40.soma.v.voltage_base",
    )
    for key in feature_keys:
        assert get_loc_name(key) == "soma"

    # a bare protocol name carries no location component
    with pytest.raises(IndexError, match="Location name not found in the feature name."):
        get_loc_name("IV_40.0")

    response_keys = (
        "IV_40.0.soma.v",
        "IV_40.soma.v",
        "ProtocolA.1.soma.v",
    )
    for key in response_keys:
        assert get_loc_name(key) == "soma"

def test_get_curr_name():
    """Check get_curr_name on feature keys, response keys and a too-short name."""
    feature_keys = (
        "IV_40.0.soma.v.voltage_base",
        "IV_40.soma.v.voltage_base",
    )
    for key in feature_keys:
        assert get_curr_name(key) == "v"

    # protocol and location only: there is no current component to return
    with pytest.raises(IndexError, match="Current name not found in the feature name."):
        get_curr_name("IV_40.0.soma")

    response_keys = (
        "IV_40.0.soma.v",
        "IV_40.soma.v",
        "ProtocolA.1.soma.v",
    )
    for key in response_keys:
        assert get_curr_name(key) == "v"
Loading