From d6534f996b0c05cf8e94aac1de6a194772313f36 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 3 Sep 2024 17:40:27 +0100 Subject: [PATCH] Repo checks: check documented methods exist (#32320) --- utils/check_repo.py | 134 ++++++++++++++++++++------------------------ 1 file changed, 61 insertions(+), 73 deletions(-) diff --git a/utils/check_repo.py b/utils/check_repo.py index 47e17c546b333e..2f0e12c9cf51be 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -31,7 +31,6 @@ It has no auto-fix mode. """ -import inspect import os import re import sys @@ -421,22 +420,15 @@ def get_model_modules() -> List[str]: "modeling_auto", "modeling_encoder_decoder", "modeling_marian", - "modeling_mmbt", - "modeling_outputs", "modeling_retribert", - "modeling_utils", "modeling_flax_auto", "modeling_flax_encoder_decoder", - "modeling_flax_utils", "modeling_speech_encoder_decoder", "modeling_flax_speech_encoder_decoder", "modeling_flax_vision_encoder_decoder", "modeling_timm_backbone", "modeling_tf_auto", "modeling_tf_encoder_decoder", - "modeling_tf_outputs", - "modeling_tf_pytorch_utils", - "modeling_tf_utils", "modeling_tf_vision_encoder_decoder", "modeling_vision_encoder_decoder", ] @@ -450,8 +442,7 @@ def get_model_modules() -> List[str]: for submodule in dir(model_module): if submodule.startswith("modeling") and submodule not in _ignore_modules: modeling_module = getattr(model_module, submodule) - if inspect.ismodule(modeling_module): - modules.append(modeling_module) + modules.append(modeling_module) return modules @@ -913,19 +904,26 @@ def find_all_documented_objects() -> List[str]: Returns: `List[str]`: The list of all object names being documented. + `Dict[str, List[str]]`: A dictionary mapping the object name (full import path, e.g. + `integrations.PeftAdapterMixin`) to its documented methods """ documented_obj = [] - for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"): - with open(doc_file, "r", encoding="utf-8", newline="\n") as f: - content = f.read() - raw_doc_objs = re.findall(r"(?:autoclass|autofunction):: transformers.(\S+)\s+", content) - documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs] + documented_methods_map = {} for doc_file in Path(PATH_TO_DOC).glob("**/*.md"): with open(doc_file, "r", encoding="utf-8", newline="\n") as f: content = f.read() raw_doc_objs = re.findall(r"\[\[autodoc\]\]\s+(\S+)\s+", content) documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs] - return documented_obj + + for obj in raw_doc_objs: + obj_public_methods = re.findall(rf"\[\[autodoc\]\] {obj}((\n\s+-.*)+)", content) + # Some objects have no methods documented + if len(obj_public_methods) == 0: + continue + else: + documented_methods_map[obj] = re.findall(r"(?<=-\s).*", obj_public_methods[0][0]) + + return documented_obj, documented_methods_map # One good reason for not being documented is to be deprecated. Put in this list deprecated objects. @@ -1063,7 +1061,7 @@ def ignore_undocumented(name: str) -> bool: def check_all_objects_are_documented(): """Check all models are properly documented.""" - documented_objs = find_all_documented_objects() + documented_objs, documented_methods_map = find_all_documented_objects() modules = transformers._modules objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")] undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)] @@ -1072,8 +1070,41 @@ def check_all_objects_are_documented(): "The following objects are in the public init so should be documented:\n - " + "\n - ".join(undocumented_objs) ) - check_docstrings_are_in_md() check_model_type_doc_match() + check_public_method_exists(documented_methods_map) + + +def check_public_method_exists(documented_methods_map): + """Check that all explicitly documented public methods are defined in the corresponding class.""" + failures = [] + for obj, methods in documented_methods_map.items(): + # Let's ensure there is no repetition + if len(set(methods)) != len(methods): + failures.append(f"Error in the documentation of {obj}: there are repeated documented methods.") + + # Navigates into the object, given the full import path + nested_path = obj.split(".") + submodule = transformers + if len(nested_path) > 1: + nested_submodules = nested_path[:-1] + for submodule_name in nested_submodules: + if submodule_name == "transformers": + continue + submodule = getattr(submodule, submodule_name) + class_name = nested_path[-1] + obj_class = getattr(submodule, class_name) + # Checks that all explicitly documented methods are defined in the class + for method in methods: + if method == "all": # Special keyword to document all public methods + continue + if not hasattr(obj_class, method): + failures.append( + "The following public method is explicitly documented but not defined in the corresponding " + f"class. class: {obj}, method: {method}" + ) + + if len(failures) > 0: + raise Exception("\n".join(failures)) def check_model_type_doc_match(): @@ -1103,50 +1134,6 @@ def check_model_type_doc_match(): ) -# Re pattern to catch :obj:`xx`, :class:`xx`, :func:`xx` or :meth:`xx`. -_re_rst_special_words = re.compile(r":(?:obj|func|class|meth):`([^`]+)`") -# Re pattern to catch things between double backquotes. -_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)") -# Re pattern to catch example introduction. -_re_rst_example = re.compile(r"^\s*Example.*::\s*$", flags=re.MULTILINE) - - -def is_rst_docstring(docstring: str) -> True: - """ - Returns `True` if `docstring` is written in rst. - """ - if _re_rst_special_words.search(docstring) is not None: - return True - if _re_double_backquotes.search(docstring) is not None: - return True - if _re_rst_example.search(docstring) is not None: - return True - return False - - -def check_docstrings_are_in_md(): - """Check all docstrings are written in md and nor rst.""" - files_with_rst = [] - for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"): - with open(file, encoding="utf-8") as f: - code = f.read() - docstrings = code.split('"""') - - for idx, docstring in enumerate(docstrings): - if idx % 2 == 0 or not is_rst_docstring(docstring): - continue - files_with_rst.append(file) - break - - if len(files_with_rst) > 0: - raise ValueError( - "The following files have docstrings written in rst:\n" - + "\n".join([f"- {f}" for f in files_with_rst]) - + "\nTo fix this run `doc-builder convert path_to_py_file` after installing `doc-builder`\n" - "(`pip install git+https://github.com/huggingface/doc-builder`)" - ) - - def check_deprecated_constant_is_up_to_date(): """ Check if the constant `DEPRECATED_MODELS` in `models/auto/configuration_auto.py` is up to date. @@ -1177,27 +1164,28 @@ def check_deprecated_constant_is_up_to_date(): def check_repo_quality(): - """Check all models are properly tested and documented.""" - print("Checking all models are included.") + """Check all models are tested and documented.""" + print("Repository-wide checks:") + print(" - checking all models are included.") check_model_list() - print("Checking all models are public.") + print(" - checking all models are public.") check_models_are_in_init() - print("Checking all models are properly tested.") + print(" - checking all models have tests.") check_all_decorator_order() check_all_models_are_tested() - print("Checking all objects are properly documented.") + print(" - checking all objects have documentation.") check_all_objects_are_documented() - print("Checking all models are in at least one auto class.") + print(" - checking all models are in at least one auto class.") check_all_models_are_auto_configured() - print("Checking all names in auto name mappings are defined.") + print(" - checking all names in auto name mappings are defined.") check_all_auto_object_names_being_defined() - print("Checking all keys in auto name mappings are defined in `CONFIG_MAPPING_NAMES`.") + print(" - checking all keys in auto name mappings are defined in `CONFIG_MAPPING_NAMES`.") check_all_auto_mapping_names_in_config_mapping_names() - print("Checking all auto mappings could be imported.") + print(" - checking all auto mappings could be imported.") check_all_auto_mappings_importable() - print("Checking all objects are equally (across frameworks) in the main __init__.") + print(" - checking all objects are equally (across frameworks) in the main __init__.") check_objects_being_equally_in_main_init() - print("Checking the DEPRECATED_MODELS constant is up to date.") + print(" - checking the DEPRECATED_MODELS constant is up to date.") check_deprecated_constant_is_up_to_date()