Skip to content

Commit

Permalink
refactor: use bin search to find potential modifications instead of d…
Browse files Browse the repository at this point in the history
…ouble lookup
  • Loading branch information
jaspervdh committed Sep 2, 2024
1 parent 3df1e54 commit db285ed
Showing 1 changed file with 68 additions and 30 deletions.
98 changes: 68 additions & 30 deletions mumble/mumble.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,6 @@ def __init__(
else:
self.protein_level_check = False
self.name_to_mass_residue_dict = self._get_name_to_mass_residue_dict()
self.rounded_mass_to_name_dict = self._get_rounded_mass_to_name_dict()
self.aa_sub_dict = self._get_aa_sub_dict()
self.mass_error = mass_error
self.fasta_file = IndexedFASTA(fasta_file, label=r"^[\n]?>([\S]*)") if fasta_file else None
Expand Down Expand Up @@ -429,22 +428,6 @@ def _get_name_to_mass_residue_dict(self):
.itertuples()
} # TODO: used named tuple here

def _get_rounded_mass_to_name_dict(self):
    """
    Build a lookup from rounded mass to the modification names sharing it.

    Rows of ``self.modification_df`` are grouped on the ``rounded_mass``
    column and the ``name`` values of each group are collected into a set.

    return:
        dict: Dictionary with rounded mass as key and set of names as value
    """
    # One row per distinct rounded mass, with all names of that group in a set.
    names_per_mass = (
        self.modification_df.groupby("rounded_mass").agg({"name": set}).reset_index()
    )

    # Pair each rounded mass with its name set.
    return dict(zip(names_per_mass["rounded_mass"], names_per_mass["name"]))

def get_localisation(
self, psm, modification_name, residue_list, restrictions
) -> list[namedtuple]:
Expand Down Expand Up @@ -524,30 +507,85 @@ def localize_mass_shift(self, psm) -> list[namedtuple]:

# get all potential modifications
try:
potential_modifications = self.rounded_mass_to_name_dict[round(mass_shift, 0)]
potential_modifications_indices = self._binary_range_search(self.modification_df["monoisotopic_mass"].to_numpy(),mass_shift,self.mass_error)
print(self.modification_df[potential_modifications_indices[0]:potential_modifications_indices[1]])
potential_modifications = self.modification_df[potential_modifications_indices[0]:potential_modifications_indices[1]]["name"]
except KeyError:
return None
localized_modifications = []
for potential_mod in potential_modifications:

if (
self.name_to_mass_residue_dict[potential_mod].mass - self.mass_error
< mass_shift
< self.name_to_mass_residue_dict[potential_mod].mass + self.mass_error
):
localized_mod = self.get_localisation(
psm,
potential_mod,
self.name_to_mass_residue_dict[potential_mod].residues,
self.name_to_mass_residue_dict[potential_mod].restrictions,
)
if localized_mod:
localized_mod = self.get_localisation(
psm,
potential_mod,
self.name_to_mass_residue_dict[potential_mod].residues,
self.name_to_mass_residue_dict[potential_mod].restrictions,
)
if localized_mod:
localized_modifications.extend(localized_mod)
else:
continue

return localized_modifications if localized_modifications else None

def _binary_range_search(self, arr, target, error) -> tuple[int, int]:
    """
    Find the index range of the values lying within ``target ± error``.

    Args:
        arr (sequence of int/float): Values sorted in ascending order
            (a list or a one-dimensional array).
        target (int/float): The midpoint value defining the center of the range.
        error (int/float): The acceptable deviation from the target,
            defining the half-width of the range.

    Returns:
        tuple: ``(start, end)`` with the indexes of the first and the last
        element inside the closed interval
        ``[target - error, target + error]``. Both indexes are inclusive,
        so the matching slice is ``arr[start : end + 1]``. An empty tuple
        is returned when no element falls inside the interval (including
        when ``arr`` is empty).
    """
    # Imported locally so the method does not depend on module-level imports.
    from bisect import bisect_left, bisect_right

    # Index of the smallest element >= lower bound; len(arr) when none exists.
    start = bisect_left(arr, target - error)
    # Index of the largest element <= upper bound; -1 when none exists.
    # (The previous hand-rolled search used len(arr) as its "not found"
    # value, which made out-of-range targets look like a full match.)
    end = bisect_right(arr, target + error) - 1

    if start <= end:
        return (start, end)
    return ()

def _get_aa_sub_dict(self):
"""
Get dictionary with name as key and mass and residue as value.
Expand Down

0 comments on commit db285ed

Please sign in to comment.