Skip to content

Commit

Permalink
Replace parse_modified_fasta_sequence with constituents_of_modified_fasta
Browse files Browse the repository at this point in the history
  • Loading branch information
wukevin committed Sep 10, 2024
1 parent e4029e1 commit 31661f6
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
18 changes: 13 additions & 5 deletions chai_lab/data/dataset/inference_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
_make_sym_ids,
)
from chai_lab.data.dataset.structure.chain import Chain
from chai_lab.data.parsing.fasta import parse_modified_fasta_sequence, read_fasta
from chai_lab.data.parsing.input_validation import identify_potential_entity_types
from chai_lab.data.parsing.fasta import get_residue_name, read_fasta
from chai_lab.data.parsing.input_validation import (
constituents_of_modified_fasta,
identify_potential_entity_types,
)
from chai_lab.data.parsing.structure.all_atom_entity_data import AllAtomEntityData
from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.data.parsing.structure.residue import Residue, get_restype
Expand Down Expand Up @@ -95,10 +98,15 @@ def raw_inputs_to_entitites_data(
residues = get_lig_residues(smiles=input.sequence)

case EntityType.PROTEIN | EntityType.RNA | EntityType.DNA:
parsed_sequence: list[str] = parse_modified_fasta_sequence(
input.sequence, entity_type
parsed_sequence: list | None = constituents_of_modified_fasta(
input.sequence
)
residues = get_polymer_residues(parsed_sequence, entity_type)
assert parsed_sequence is not None
expanded_sequence = [
get_residue_name(r, entity_type=entity_type) if len(r) == 1 else r
for r in parsed_sequence
]
residues = get_polymer_residues(expanded_sequence, entity_type)
case _:
raise NotImplementedError
assert residues is not None
Expand Down
6 changes: 4 additions & 2 deletions chai_lab/data/parsing/fasta.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def get_residue_name(
fasta_code: str,
entity_type: EntityType,
) -> str:
if len(fasta_code) != 1:
raise ValueError("Cannot handle non-single chars: {}".format(fasta_code))
match entity_type:
case EntityType.PROTEIN:
return restype_1to3_with_x.get(fasta_code, "UNK")
Expand All @@ -48,8 +50,8 @@ def get_residue_name(

def parse_modified_fasta_sequence(sequence: str, entity_type: EntityType) -> list[str]:
"""
Parses a fasta-like string containing modified residues
in brackets, returns a list of residue codes.
Parses a fasta-like string containing modified residues in brackets.
Returns a list of residue codes expanded to their full names (e.g., K > LYS)
"""
pattern = r"[A-Z]|\[[A-Z0-9]+\]"
residues = re.findall(pattern, sequence)
Expand Down
5 changes: 3 additions & 2 deletions chai_lab/data/parsing/input_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

def constituents_of_modified_fasta(x: str) -> list[str] | None:
"""
Accepts RNA/DNA inputs: 'agtc', 'AGT[ASP]TG', etc. Does not accept SMILES strings.
Returns constituents, e.g, [A, G, T, ASP, T, G] or None if string is incorrect
Accepts RNA/DNA inputs: 'agtc', 'AGT(ASP)TG', etc. Does not accept SMILES strings.
Returns constituents, e.g, [A, G, T, ASP, T, G] or None if string is incorrect.
Everything in returned list is single character, except for blocks specified in brackets.
"""
x = x.strip().upper()
# it is a bit strange that digits are here, but [NH2] was in one protein
Expand Down

0 comments on commit 31661f6

Please sign in to comment.