Skip to content

Commit

Permalink
Repo checks: check documented methods exist (#32320)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante authored Sep 3, 2024
1 parent 979d24e commit d6534f9
Showing 1 changed file with 61 additions and 73 deletions.
134 changes: 61 additions & 73 deletions utils/check_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
It has no auto-fix mode.
"""

import inspect
import os
import re
import sys
Expand Down Expand Up @@ -421,22 +420,15 @@ def get_model_modules() -> List[str]:
"modeling_auto",
"modeling_encoder_decoder",
"modeling_marian",
"modeling_mmbt",
"modeling_outputs",
"modeling_retribert",
"modeling_utils",
"modeling_flax_auto",
"modeling_flax_encoder_decoder",
"modeling_flax_utils",
"modeling_speech_encoder_decoder",
"modeling_flax_speech_encoder_decoder",
"modeling_flax_vision_encoder_decoder",
"modeling_timm_backbone",
"modeling_tf_auto",
"modeling_tf_encoder_decoder",
"modeling_tf_outputs",
"modeling_tf_pytorch_utils",
"modeling_tf_utils",
"modeling_tf_vision_encoder_decoder",
"modeling_vision_encoder_decoder",
]
Expand All @@ -450,8 +442,7 @@ def get_model_modules() -> List[str]:
for submodule in dir(model_module):
if submodule.startswith("modeling") and submodule not in _ignore_modules:
modeling_module = getattr(model_module, submodule)
if inspect.ismodule(modeling_module):
modules.append(modeling_module)
modules.append(modeling_module)
return modules


Expand Down Expand Up @@ -913,19 +904,26 @@ def find_all_documented_objects() -> List[str]:
Returns:
`List[str]`: The list of all object names being documented.
`Dict[str, List[str]]`: A dictionary mapping the object name (full import path, e.g.
`integrations.PeftAdapterMixin`) to its documented methods
"""
documented_obj = []
for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"):
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
content = f.read()
raw_doc_objs = re.findall(r"(?:autoclass|autofunction):: transformers.(\S+)\s+", content)
documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs]
documented_methods_map = {}
for doc_file in Path(PATH_TO_DOC).glob("**/*.md"):
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
content = f.read()
raw_doc_objs = re.findall(r"\[\[autodoc\]\]\s+(\S+)\s+", content)
documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs]
return documented_obj

for obj in raw_doc_objs:
obj_public_methods = re.findall(rf"\[\[autodoc\]\] {obj}((\n\s+-.*)+)", content)
# Some objects have no methods documented
if len(obj_public_methods) == 0:
continue
else:
documented_methods_map[obj] = re.findall(r"(?<=-\s).*", obj_public_methods[0][0])

return documented_obj, documented_methods_map


# One good reason for not being documented is to be deprecated. Put in this list deprecated objects.
Expand Down Expand Up @@ -1063,7 +1061,7 @@ def ignore_undocumented(name: str) -> bool:

def check_all_objects_are_documented():
"""Check all models are properly documented."""
documented_objs = find_all_documented_objects()
documented_objs, documented_methods_map = find_all_documented_objects()
modules = transformers._modules
objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
Expand All @@ -1072,8 +1070,41 @@ def check_all_objects_are_documented():
"The following objects are in the public init so should be documented:\n - "
+ "\n - ".join(undocumented_objs)
)
check_docstrings_are_in_md()
check_model_type_doc_match()
check_public_method_exists(documented_methods_map)


def check_public_method_exists(documented_methods_map):
"""Check that all explicitly documented public methods are defined in the corresponding class."""
failures = []
for obj, methods in documented_methods_map.items():
# Let's ensure there is no repetition
if len(set(methods)) != len(methods):
failures.append(f"Error in the documentation of {obj}: there are repeated documented methods.")

# Navigates into the object, given the full import path
nested_path = obj.split(".")
submodule = transformers
if len(nested_path) > 1:
nested_submodules = nested_path[:-1]
for submodule_name in nested_submodules:
if submodule_name == "transformers":
continue
submodule = getattr(submodule, submodule_name)
class_name = nested_path[-1]
obj_class = getattr(submodule, class_name)
# Checks that all explicitly documented methods are defined in the class
for method in methods:
if method == "all": # Special keyword to document all public methods
continue
if not hasattr(obj_class, method):
failures.append(
"The following public method is explicitly documented but not defined in the corresponding "
f"class. class: {obj}, method: {method}"
)

if len(failures) > 0:
raise Exception("\n".join(failures))


def check_model_type_doc_match():
Expand Down Expand Up @@ -1103,50 +1134,6 @@ def check_model_type_doc_match():
)


# Re pattern to catch :obj:`xx`, :class:`xx`, :func:`xx` or :meth:`xx`.
_re_rst_special_words = re.compile(r":(?:obj|func|class|meth):`([^`]+)`")
# Re pattern to catch things between double backquotes.
_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
# Re pattern to catch example introduction.
_re_rst_example = re.compile(r"^\s*Example.*::\s*$", flags=re.MULTILINE)


def is_rst_docstring(docstring: str) -> True:
"""
Returns `True` if `docstring` is written in rst.
"""
if _re_rst_special_words.search(docstring) is not None:
return True
if _re_double_backquotes.search(docstring) is not None:
return True
if _re_rst_example.search(docstring) is not None:
return True
return False


def check_docstrings_are_in_md():
"""Check all docstrings are written in md and nor rst."""
files_with_rst = []
for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
with open(file, encoding="utf-8") as f:
code = f.read()
docstrings = code.split('"""')

for idx, docstring in enumerate(docstrings):
if idx % 2 == 0 or not is_rst_docstring(docstring):
continue
files_with_rst.append(file)
break

if len(files_with_rst) > 0:
raise ValueError(
"The following files have docstrings written in rst:\n"
+ "\n".join([f"- {f}" for f in files_with_rst])
+ "\nTo fix this run `doc-builder convert path_to_py_file` after installing `doc-builder`\n"
"(`pip install git+https://github.com/huggingface/doc-builder`)"
)


def check_deprecated_constant_is_up_to_date():
"""
Check if the constant `DEPRECATED_MODELS` in `models/auto/configuration_auto.py` is up to date.
Expand Down Expand Up @@ -1177,27 +1164,28 @@ def check_deprecated_constant_is_up_to_date():


def check_repo_quality():
"""Check all models are properly tested and documented."""
print("Checking all models are included.")
"""Check all models are tested and documented."""
print("Repository-wide checks:")
print(" - checking all models are included.")
check_model_list()
print("Checking all models are public.")
print(" - checking all models are public.")
check_models_are_in_init()
print("Checking all models are properly tested.")
print(" - checking all models have tests.")
check_all_decorator_order()
check_all_models_are_tested()
print("Checking all objects are properly documented.")
print(" - checking all objects have documentation.")
check_all_objects_are_documented()
print("Checking all models are in at least one auto class.")
print(" - checking all models are in at least one auto class.")
check_all_models_are_auto_configured()
print("Checking all names in auto name mappings are defined.")
print(" - checking all names in auto name mappings are defined.")
check_all_auto_object_names_being_defined()
print("Checking all keys in auto name mappings are defined in `CONFIG_MAPPING_NAMES`.")
print(" - checking all keys in auto name mappings are defined in `CONFIG_MAPPING_NAMES`.")
check_all_auto_mapping_names_in_config_mapping_names()
print("Checking all auto mappings could be imported.")
print(" - checking all auto mappings could be imported.")
check_all_auto_mappings_importable()
print("Checking all objects are equally (across frameworks) in the main __init__.")
print(" - checking all objects are equally (across frameworks) in the main __init__.")
check_objects_being_equally_in_main_init()
print("Checking the DEPRECATED_MODELS constant is up to date.")
print(" - checking the DEPRECATED_MODELS constant is up to date.")
check_deprecated_constant_is_up_to_date()


Expand Down

0 comments on commit d6534f9

Please sign in to comment.