Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix compute_scores to handle protocol names with '.' #166

Merged
merged 5 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 7 additions & 15 deletions bluepyemodel/emodel_pipeline/plotting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from bluepyemodel.evaluation.protocols import ThresholdBasedProtocol
from bluepyemodel.evaluation.recordings import FixedDtRecordingCustom
from bluepyemodel.evaluation.recordings import FixedDtRecordingStimulus
from bluepyemodel.tools.utils import get_curr_name
from bluepyemodel.tools.utils import get_loc_name
from bluepyemodel.tools.utils import get_protocol_name

logger = logging.getLogger("__main__")

Expand Down Expand Up @@ -379,12 +382,7 @@ def get_simulated_FI_curve_for_plotting(evaluator, responses, prot_name):
simulated_amp = []
for val in values:
if prot_name.lower() in val.lower():
# val is e.g. IV_40.soma.maximum_voltage_from_voltagebase
n = val.split(".")
# case where protocol has '.' in its name, e.g. IV_40.0
if len(n) == 4 and n[1].isdigit():
n = [".".join(n[:2]), n[2], n[3]]
protocol_name = n[0]
protocol_name = get_protocol_name(val)
amp_temp = float(protocol_name.split("_")[-1])
if "mean_frequency" in val:
simulated_freq.append(values[val])
Expand Down Expand Up @@ -593,17 +591,11 @@ def get_ordered_currentscape_keys(keys):

ordered_keys = {}
for name in keys:
n = name.split(".")
# case where protocol has '.' in its name, e.g. IV_-100.0
if len(n) == 4 and n[1].isdigit():
n = [".".join(n[:2]), n[2], n[3]]
prot_name = n[0]
prot_name = get_protocol_name(name)
# prot_name can be e.g. RMPProtocol, or RMPProtocol_apical055
if not any(to_skip_ in prot_name for to_skip_ in to_skip):
if len(n) != 3:
raise ValueError(f"Expected 3 elements in {n}")
loc_name = n[1]
curr_name = n[2]
loc_name = get_loc_name(name)
curr_name = get_curr_name(name)

if prot_name not in ordered_keys:
ordered_keys[prot_name] = {}
Expand Down
56 changes: 51 additions & 5 deletions bluepyemodel/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,56 @@ def get_amplitude_from_feature_key(feat_key):
Args:
feat_key (str): feature key, e.g. IV_40.soma.maximum_voltage_from_voltagebase
"""
n = feat_key.split(".")
# case where protocol has '.' in its name, e.g. IV_40.0
if len(n) == 4 and n[1].isdigit():
n = [".".join(n[:2]), n[2], n[3]]
protocol_name = n[0]
protocol_name = get_protocol_name(feat_key)

return float(protocol_name.split("_")[-1])


def combine_parts_if_dot_in_protocol(feature_name):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you say for this function and get_protocol_name and get_loc_name and get_curr_name in their docstrings that it works both with a feature str (e.g. "IV_40.0.soma.v.voltage_base") and for a response key (e.g.""IV_40.0.soma.v"). Also adding examples of expected names would be a +

"""
Combine the first two elements of a list if the second element is numeric,
indicating the presence of a dot in the protocol.

Args:
feature_name (list): The list of split parts from the feature name.
"""
if len(feature_name) > 1 and feature_name[1].isdigit():
return [".".join(feature_name[:2])] + feature_name[2:]
return feature_name


def get_protocol_name(feature_name):
"""
Extract the protocol name from the feature name.

Args:
feature_name (str): The full feature name string.
"""
n = combine_parts_if_dot_in_protocol(feature_name.split("."))
return n[0]


def get_loc_name(feature_name):
"""
Extract the location name from the feature name.

Args:
feature_name (str): The full feature name string.
"""
n = combine_parts_if_dot_in_protocol(feature_name.split("."))
if len(n) < 2:
raise IndexError("cannot get location name from feature name")
return n[1]


def get_curr_name(feature_name):
"""
Extract the current name from the feature name.

Args:
feature_name (str): The full feature name string.
"""
n = combine_parts_if_dot_in_protocol(feature_name.split("."))
if len(n) < 3:
raise IndexError("cannot get current name from feature name")
return n[2]
7 changes: 2 additions & 5 deletions bluepyemodel/validation/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from bluepyemodel.evaluation.evaluation import compute_responses
from bluepyemodel.evaluation.evaluation import get_evaluator_from_access_point
from bluepyemodel.tools.utils import are_same_protocol
from bluepyemodel.tools.utils import get_protocol_name
from bluepyemodel.validation import validation_functions

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,11 +77,7 @@ def compute_scores(model, validation_protocols):

scores = model.evaluator.fitness_calculator.calculate_scores(model.responses)
for feature_name in scores:
n = feature_name.split(".")
# case where protocol has '.' in its name, e.g. IV_40.0
if n[1].isdigit():
n = [".".join(n[:2])] + n[2:]
protocol_name = n[0]
protocol_name = get_protocol_name(feature_name)
if any(are_same_protocol(p, protocol_name) for p in validation_protocols):
model.scores_validation[feature_name] = scores[feature_name]
else:
Expand Down
38 changes: 38 additions & 0 deletions tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from bluepyemodel.tools.utils import are_same_protocol
from bluepyemodel.tools.utils import format_protocol_name_to_list
from bluepyemodel.tools.utils import select_rec_for_thumbnail
from bluepyemodel.tools.utils import get_protocol_name
from bluepyemodel.tools.utils import get_loc_name
from bluepyemodel.tools.utils import get_curr_name
from tests.utils import DATA


Expand Down Expand Up @@ -136,3 +139,38 @@ def test_select_rec_for_thumbnail():
assert (
select_rec_for_thumbnail(rec_names, thumbnail_rec="sAHP_20.soma.v") == "IDrest_130.soma.v"
)


def test_get_protocol_name():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the tests could you also try with strings from response keys (i.e. without the feature at the end, e.g. 'IV_40.0.soma.v')

feature_name = "IV_40.0.soma.v.voltage_base"
assert get_protocol_name(feature_name) == "IV_40.0"

feature_name = "IV_40.soma.v.voltage_base"
assert get_protocol_name(feature_name) == "IV_40"

feature_name = "ProtocolA.1.soma.some_feature"
assert get_protocol_name(feature_name) == "ProtocolA.1"


def test_get_loc_name():
feature_name = "IV_40.0.soma.v.voltage_base"
assert get_loc_name(feature_name) == "soma"

feature_name = "IV_40.soma.v.voltage_base"
assert get_loc_name(feature_name) == "soma"

feature_name = "IV_40.0"
with pytest.raises(IndexError, match="cannot get location name from feature name"):
get_loc_name(feature_name)


def test_get_curr_name():
feature_name = "IV_40.0.soma.v.voltage_base"
assert get_curr_name(feature_name) == "v"

feature_name = "IV_40.soma.v.voltage_base"
assert get_curr_name(feature_name) == "v"

feature_name = "IV_40.0.soma"
with pytest.raises(IndexError, match="cannot get current name from feature name"):
get_curr_name(feature_name)
Loading