Skip to content

Commit

Permalink
fix regexes with escape sequence (huggingface#17943)
Browse files Browse the repository at this point in the history
  • Loading branch information
stas00 authored and viclzhu committed Jul 18, 2022
1 parent 91d925a commit e5b5497
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/transformers/dynamic_module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def get_relative_imports(module_file):
content = f.read()

# Imports of the form `import .xxx`
relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
# Imports of the form `from .xxx import yyy`
relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
# Unique-ify
return list(set(relative_imports))

Expand Down Expand Up @@ -122,9 +122,9 @@ def check_imports(filename):
content = f.read()

# Imports of the form `import xxx`
imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
# Imports of the form `from xxx import yyy`
imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
# Only keep the top-level module
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def dtype_byte_size(dtype):
"""
if dtype == torch.bool:
return 1 / 8
bit_search = re.search("[^\d](\d+)$", str(dtype))
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
Expand Down

0 comments on commit e5b5497

Please sign in to comment.