From 269e852134c5dee8484878a4145a458ae4fd302b Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 12 Jul 2023 15:38:27 +0200 Subject: [PATCH] unify docformatter config (#1642) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> additional_dependencies: [tomli] --- .pre-commit-config.yaml | 14 ++-------- pyproject.toml | 14 ++++------ setup.py | 3 +++ src/flash/audio/speech_recognition/collate.py | 1 + src/flash/core/classification.py | 1 + src/flash/core/data/batch.py | 1 + src/flash/core/data/data_module.py | 1 + .../core/data/io/classification_input.py | 1 + src/flash/core/data/io/input.py | 1 + src/flash/core/data/io/input_transform.py | 27 +++++++++++++++++++ src/flash/core/data/io/output.py | 1 + src/flash/core/data/io/output_transform.py | 3 +++ src/flash/core/data/splits.py | 1 + src/flash/core/data/utilities/data_frame.py | 2 ++ src/flash/core/data/utilities/loading.py | 4 +++ src/flash/core/data/utilities/paths.py | 4 +++ src/flash/core/data/utilities/samples.py | 2 ++ src/flash/core/data/utils.py | 1 + src/flash/core/finetuning.py | 1 + src/flash/core/heads.py | 1 + src/flash/core/integrations/fiftyone/utils.py | 1 + .../pytorch_forecasting/adapter.py | 1 + src/flash/core/model.py | 2 ++ src/flash/core/optimizers/lamb.py | 1 + src/flash/core/optimizers/lars.py | 2 ++ src/flash/core/optimizers/lr_scheduler.py | 1 + src/flash/core/registry.py | 1 + .../core/serve/_compat/cached_property.py | 1 + src/flash/core/serve/component.py | 6 +++++ src/flash/core/serve/composition.py | 1 + src/flash/core/serve/core.py | 7 +++++ src/flash/core/serve/dag/optimization.py | 6 +++++ src/flash/core/serve/dag/order.py | 6 +++++ src/flash/core/serve/dag/rewrite.py | 8 ++++++ src/flash/core/serve/dag/task.py | 9 +++++++ src/flash/core/serve/decorators.py | 2 ++ src/flash/core/serve/execution.py | 4 +++ src/flash/core/serve/interfaces/models.py | 1 + src/flash/core/serve/server.py | 2 ++ src/flash/core/serve/types/base.py | 3 +++ src/flash/core/serve/types/bbox.py | 1 + src/flash/core/serve/types/image.py | 1 + src/flash/core/serve/types/label.py | 1 + src/flash/core/serve/types/repeated.py | 1 + src/flash/core/serve/types/table.py | 1 + src/flash/core/serve/types/text.py | 1 + src/flash/core/serve/utils.py | 2 ++ src/flash/core/trainer.py | 3 +++ src/flash/core/utilities/imports.py | 2 ++ src/flash/core/utilities/lightning_cli.py | 8 ++++++ src/flash/core/utilities/stages.py | 1 + .../integrations/learn2learn.py | 2 ++ .../image/embedding/heads/vissl_heads.py | 1 + .../image/embedding/strategies/default.py | 1 + .../embedding/vissl/transforms/multicrop.py | 1 + .../embedding/vissl/transforms/utilities.py | 2 ++ .../image/face_detection/input_transform.py | 1 + .../pointcloud/detection/open3d_ml/app.py | 1 + .../pointcloud/segmentation/open3d_ml/app.py | 1 + .../open3d_ml/sequences_dataset.py | 1 + src/flash/template/classification/data.py | 1 + src/flash/text/seq2seq/core/model.py | 1 + src/flash/video/classification/model.py | 1 + src/flash/video/classification/utils.py | 1 + tests/core/serve/test_dag/test_order.py | 7 +++++ tests/core/serve/test_gridbase_validations.py | 3 +++ tests/core/utilities/test_embedder.py | 2 ++ tests/helpers/boring_model.py | 1 + tests/video/classification/test_model.py | 3 +++ 69 files changed, 181 insertions(+), 21 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e419d4577d..73a69f2b04 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,10 +51,8 @@ repos: rev: v1.7.3 hooks: - id: docformatter - args: - - "--in-place" - - "--wrap-summaries=120" - - "--wrap-descriptions=120" + additional_dependencies: [tomli] + args: ["--in-place"] - repo: https://github.com/psf/black rev: 23.3.0 @@ -62,14 +60,6 @@ repos: - id: black name: Format code - - repo: https://github.com/asottile/blacken-docs - rev: 1.14.0 - hooks: - - id: blacken-docs - args: - - "--line-length=120" - - "--skip-errors" - - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.0.276 hooks: diff --git a/pyproject.toml b/pyproject.toml index aaf1f15b77..254cc502b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,15 +48,11 @@ exclude_lines = [ line-length = 120 exclude = "(.eggs|.git|.hg|.mypy_cache|.venv|_build|buck-out|build|dist)" -[tool.isort] -known_first_party = [ - "flash", - "examples", - "tests", -] -skip_glob = [] -profile = "black" -line_length = 120 +[tool.docformatter] +recursive = true +wrap-summaries = 120 +wrap-descriptions = 120 +blank = true [tool.ruff] diff --git a/setup.py b/setup.py index c74b986a6d..d72d93b482 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str: >>> _load_readme_description(_PATH_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE '
...' + """ path_readme = os.path.join(path_dir, "README.md") text = open(path_readme, encoding="utf-8").read() @@ -65,6 +66,7 @@ def _augment_requirement(ln: str, comment_char: str = "#", unfreeze: bool = True 'arrow>=1.2.0, <=1.2.2 # strict' >>> _augment_requirement("arrow", unfreeze=True) 'arrow' + """ # filer all comments if comment_char in ln: @@ -95,6 +97,7 @@ def _load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: boo >>> path_req = os.path.join(_PATH_ROOT, "requirements") >>> _load_requirements(path_req, "docs.txt") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE ['sphinx>=4.0', ...] + """ with open(os.path.join(path_dir, file_name)) as file: lines = [ln.strip() for ln in file.readlines()] diff --git a/src/flash/audio/speech_recognition/collate.py b/src/flash/audio/speech_recognition/collate.py index 0346bb04f4..90471bedf3 100644 --- a/src/flash/audio/speech_recognition/collate.py +++ b/src/flash/audio/speech_recognition/collate.py @@ -50,6 +50,7 @@ class DataCollatorCTCWithPadding: If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). + """ processor: AutoProcessor diff --git a/src/flash/core/classification.py b/src/flash/core/classification.py index 90740340a3..394b824421 100644 --- a/src/flash/core/classification.py +++ b/src/flash/core/classification.py @@ -127,6 +127,7 @@ class ClassificationOutput(Output): Args: multi_label: If true, treats outputs as multi label logits. + """ def __init__(self, multi_label: bool = False): diff --git a/src/flash/core/data/batch.py b/src/flash/core/data/batch.py index a483375372..617e03a1f8 100644 --- a/src/flash/core/data/batch.py +++ b/src/flash/core/data/batch.py @@ -58,6 +58,7 @@ def default_uncollate(batch: Any) -> List[Any]: ValueError: If the input is a ``dict`` whose values are not all list-like. ValueError: If the input is a ``dict`` whose values are not all the same length. ValueError: If the input is not a ``dict`` or list-like. + """ if isinstance(batch, dict): if any(not _is_list_like_excluding_str(sub_batch) for sub_batch in batch.values()): diff --git a/src/flash/core/data/data_module.py b/src/flash/core/data/data_module.py index 12f310d18c..9118300c96 100644 --- a/src/flash/core/data/data_module.py +++ b/src/flash/core/data/data_module.py @@ -566,6 +566,7 @@ def _split_train_val( Returns: A tuple containing the training and validation datasets + """ if not isinstance(val_split, float) or (isinstance(val_split, float) and val_split > 1 or val_split < 0): diff --git a/src/flash/core/data/io/classification_input.py b/src/flash/core/data/io/classification_input.py index 4c5cd60622..31daf8273c 100644 --- a/src/flash/core/data/io/classification_input.py +++ b/src/flash/core/data/io/classification_input.py @@ -64,5 +64,6 @@ def format_target(self, target: Any) -> Any: Returns: The formatted target. + """ return getattr(self, "target_formatter", lambda x: x)(target) diff --git a/src/flash/core/data/io/input.py b/src/flash/core/data/io/input.py index 7e42c448f2..199aee6f00 100644 --- a/src/flash/core/data/io/input.py +++ b/src/flash/core/data/io/input.py @@ -89,6 +89,7 @@ def _has_len(data: Union[Sequence, Iterable]) -> bool: Args: data: The object to check for length support. + """ try: len(data) diff --git a/src/flash/core/data/io/input_transform.py b/src/flash/core/data/io/input_transform.py index f65af3d78c..d436ac681d 100644 --- a/src/flash/core/data/io/input_transform.py +++ b/src/flash/core/data/io/input_transform.py @@ -84,6 +84,7 @@ def per_sample_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_sample_transform(self) -> Callable: return ApplyToKeys("input", my_func) + """ pass @@ -97,6 +98,7 @@ def train_per_sample_transform(self) -> Callable: DataKeys.TARGET: ..., DataKeys.METADATA: ..., } + """ return self.per_sample_transform() @@ -121,6 +123,7 @@ def val_per_sample_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_sample_transform(self) -> Callable: return ApplyToKeys("input", my_func) + """ return self.per_sample_transform() @@ -134,6 +137,7 @@ def test_per_sample_transform(self) -> Callable: DataKeys.TARGET: ..., DataKeys.METADATA: ..., } + """ return self.per_sample_transform() @@ -158,6 +162,7 @@ def predict_per_sample_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_sample_transform(self) -> Callable: return ApplyToKeys("input", my_func) + """ return self.per_sample_transform() @@ -182,6 +187,7 @@ def serve_per_sample_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_sample_transform(self) -> Callable: return ApplyToKeys("input", my_func) + """ return self.per_sample_transform() @@ -210,6 +216,7 @@ def per_sample_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def per_sample_transform_on_device(self) -> Callable: return ApplyToKeys("input", my_func) + """ pass @@ -223,6 +230,7 @@ def train_per_sample_transform_on_device(self) -> Callable: DataKeys.TARGET: ..., DataKeys.METADATA: ..., } + """ return self.per_sample_transform_on_device() @@ -247,6 +255,7 @@ def val_per_sample_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def per_sample_transform_on_device(self) -> Callable: return ApplyToKeys("input", my_func) + """ return self.per_sample_transform_on_device() @@ -260,6 +269,7 @@ def test_per_sample_transform_on_device(self) -> Callable: DataKeys.TARGET: ..., DataKeys.METADATA: ..., } + """ return self.per_sample_transform_on_device() @@ -284,6 +294,7 @@ def predict_per_sample_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def per_sample_transform_on_device(self) -> Callable: return ApplyToKeys("input", my_func) + """ return self.per_sample_transform_on_device() @@ -308,6 +319,7 @@ def serve_per_sample_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def serve_per_sample_transform_on_device(self) -> Callable: return ApplyToKeys("input", my_func) + """ return self.per_sample_transform_on_device() @@ -336,6 +348,7 @@ def per_batch_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_batch_transform(self) -> Callable: return ApplyToKeys("input", my_func) + """ pass @@ -349,6 +362,7 @@ def train_per_batch_transform(self) -> Callable: DataKeys.TARGET: ..., DataKeys.METADATA: ..., } + """ return self.per_batch_transform() @@ -373,6 +387,7 @@ def val_per_batch_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_batch_transform(self) -> Callable: return ApplyToKeys("input", my_func) + """ return self.per_batch_transform() @@ -386,6 +401,7 @@ def test_per_batch_transform(self) -> Callable: DataKeys.TARGET: ..., DataKeys.METADATA: ..., } + """ return self.per_batch_transform() @@ -410,6 +426,7 @@ def predict_per_batch_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_batch_transform(self) -> Callable: return ApplyToKeys("input", my_func) + """ return self.per_batch_transform() @@ -434,6 +451,7 @@ def serve_per_batch_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_batch_transform(self) -> Callable: return ApplyToKeys("input", my_func) + """ return self.per_batch_transform() @@ -462,6 +480,7 @@ def per_batch_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def per_batch_transform_on_device(self) -> Callable: return ApplyToKeys("input", my_func) + """ pass @@ -475,6 +494,7 @@ def train_per_batch_transform_on_device(self) -> Callable: DataKeys.TARGET: ..., DataKeys.METADATA: ..., } + """ return self.per_batch_transform_on_device() @@ -499,6 +519,7 @@ def val_per_batch_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def per_batch_transform_on_device(self) -> Callable: return ApplyToKeys("input", my_func) + """ return self.per_batch_transform_on_device() @@ -512,6 +533,7 @@ def test_per_batch_transform_on_device(self) -> Callable: DataKeys.TARGET: ..., DataKeys.METADATA: ..., } + """ return self.per_batch_transform_on_device() @@ -536,6 +558,7 @@ def predict_per_batch_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def per_batch_transform_on_device(self) -> Callable: return ApplyToKeys("input", my_func) + """ return self.per_batch_transform_on_device() @@ -560,6 +583,7 @@ def serve_per_batch_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def serve_per_batch_transform_on_device(self) -> Callable: return ApplyToKeys("input", my_func) + """ return self.per_batch_transform_on_device() @@ -606,6 +630,7 @@ def _per_batch_transform(self, batch: Any, stage: RunningStage) -> Any: .. note:: This option is mutually exclusive with :meth:`per_sample_transform_on_device`, since if both are specified, uncollation has to be applied. + """ return self.current_transform(stage=stage, current_fn="per_batch_transform")(batch) @@ -620,6 +645,7 @@ def _per_sample_transform_on_device(self, sample: Any, stage: RunningStage) -> A specified, uncollation has to be applied. .. note:: This function won't be called within the dataloader workers, since to make that happen each of the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU). + """ fn = self.current_transform(stage=stage, current_fn="per_sample_transform_on_device") if isinstance(sample, list): @@ -631,6 +657,7 @@ def _per_batch_transform_on_device(self, batch: Any, stage: RunningStage) -> Any .. note:: This function won't be called within the dataloader workers, since to make that happen each of the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU). + """ return self.current_transform(stage=stage, current_fn="per_batch_transform_on_device")(batch) diff --git a/src/flash/core/data/io/output.py b/src/flash/core/data/io/output.py index 0b8e7467a8..e0765ca754 100644 --- a/src/flash/core/data/io/output.py +++ b/src/flash/core/data/io/output.py @@ -37,6 +37,7 @@ def transform(sample: Any) -> Any: Returns: The converted output. + """ return sample diff --git a/src/flash/core/data/io/output_transform.py b/src/flash/core/data/io/output_transform.py index 0e691ce51a..c0e799542e 100644 --- a/src/flash/core/data/io/output_transform.py +++ b/src/flash/core/data/io/output_transform.py @@ -25,6 +25,7 @@ def per_batch_transform(batch: Any) -> Any: """Transforms to apply on a whole batch before uncollation to individual samples. Can involve both CPU and Device transforms as this is not applied in separate workers. + """ return batch @@ -33,6 +34,7 @@ def per_sample_transform(sample: Any) -> Any: """Transforms to apply to a single sample after splitting up the batch. Can involve both CPU and Device transforms as this is not applied in separate workers. + """ return sample @@ -41,6 +43,7 @@ def uncollate(batch: Any) -> Any: """Uncollates a batch into single samples. Tries to preserve the type wherever possible. + """ return default_uncollate(batch) diff --git a/src/flash/core/data/splits.py b/src/flash/core/data/splits.py index a51e29fade..b5b8492fef 100644 --- a/src/flash/core/data/splits.py +++ b/src/flash/core/data/splits.py @@ -21,6 +21,7 @@ class SplitDataset(Properties, Dataset): split_ds = SplitDataset(dataset, indices=[10, 14, 25]) split_ds = SplitDataset(dataset, indices=[10, 10, 10, 14, 25], use_duplicated_indices=True) + """ def __init__( diff --git a/src/flash/core/data/utilities/data_frame.py b/src/flash/core/data/utilities/data_frame.py index d2f4d7fc8f..ab4f99015b 100644 --- a/src/flash/core/data/utilities/data_frame.py +++ b/src/flash/core/data/utilities/data_frame.py @@ -30,6 +30,7 @@ def resolve_targets(data_frame: pd.DataFrame, target_keys: Union[str, List[str]] Args: data_frame: The ``pd.DataFrame`` containing the target column / columns. target_keys: The column in the data frame (or a list of columns) from which to resolve the target. + """ if not isinstance(target_keys, List): return data_frame[target_keys].tolist() @@ -63,6 +64,7 @@ def resolve_files( root: The root path to use when resolving files. resolver: The resolver function to use. This function should receive the root and a file ID as input and return the path to an existing file. + """ if resolver is None: resolver = default_resolver diff --git a/src/flash/core/data/utilities/loading.py b/src/flash/core/data/utilities/loading.py index e169e4ec3b..0075309c8d 100644 --- a/src/flash/core/data/utilities/loading.py +++ b/src/flash/core/data/utilities/loading.py @@ -175,6 +175,7 @@ def load_image(file_path: str): Args: file_path: The image file to load. + """ return load(file_path, _image_loaders) @@ -186,6 +187,7 @@ def load_spectrogram(file_path: str, sampling_rate: int = 16000, n_fft: int = 40 file_path: The file to load. sampling_rate: The sampling rate to resample to if loading from an audio file. n_fft: The size of the FFT to use when creating a spectrogram from an audio file. + """ loaders = copy.copy(_spectrogram_loaders) loaders[AUDIO_EXTENSIONS] = partial(loaders[AUDIO_EXTENSIONS], sampling_rate=sampling_rate, n_fft=n_fft) @@ -198,6 +200,7 @@ def load_audio(file_path: str, sampling_rate: int = 16000): Args: file_path: The file to load. sampling_rate: The sampling rate to resample to. + """ loaders = { extensions: partial(loader, sampling_rate=sampling_rate) for extensions, loader in _audio_loaders.items() @@ -211,6 +214,7 @@ def load_data_frame(file_path: str, encoding: str = "utf-8"): Args: file_path: The file to load. encoding: The encoding to use when reading the file. + """ loaders = {extensions: partial(loader, encoding=encoding) for extensions, loader in _data_frame_loaders.items()} return load(file_path, loaders) diff --git a/src/flash/core/data/utilities/paths.py b/src/flash/core/data/utilities/paths.py index 7d8850070e..96939e0e15 100644 --- a/src/flash/core/data/utilities/paths.py +++ b/src/flash/core/data/utilities/paths.py @@ -32,6 +32,7 @@ def has_file_allowed_extension(filename: PATH_TYPE, extensions: Tuple[str, ...]) Returns: bool: True if the filename ends with one of given extensions + """ return str(filename).lower().endswith(extensions) @@ -59,6 +60,7 @@ def make_dataset( Returns: (files, targets) Tuple containing the list of files and corresponding list of targets. + """ files, targets = [], [] directory = os.path.expanduser(str(directory)) @@ -104,6 +106,7 @@ def list_subdirs(folder: PATH_TYPE) -> List[str]: Returns: The list of subdirectories. + """ return list(sorted_alphanumeric(d.name for d in os.scandir(str(folder)) if d.is_dir())) @@ -146,6 +149,7 @@ def filter_valid_files( Returns: The filtered lists. + """ if not isinstance(files, List): files = [files] diff --git a/src/flash/core/data/utilities/samples.py b/src/flash/core/data/utilities/samples.py index 70a2bdf8db..8f26462b1d 100644 --- a/src/flash/core/data/utilities/samples.py +++ b/src/flash/core/data/utilities/samples.py @@ -32,6 +32,7 @@ def to_sample(input: Any) -> Dict[str, Any]: Returns: A sample dictionary. + """ if isinstance(input, dict) and DataKeys.INPUT in input: return input @@ -51,6 +52,7 @@ def to_samples(inputs: List[Any], targets: Optional[List[Any]] = None) -> List[D Returns: A list of sample dictionaries. + """ if targets is None: return [to_sample(input) for input in inputs] diff --git a/src/flash/core/data/utils.py b/src/flash/core/data/utils.py index 431bb0ad80..e142615d74 100644 --- a/src/flash/core/data/utils.py +++ b/src/flash/core/data/utils.py @@ -76,6 +76,7 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: >>> download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data") >>> os.listdir("./data") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE [...] + """ # Disable warning about making an insecure request urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) diff --git a/src/flash/core/finetuning.py b/src/flash/core/finetuning.py index 71b684de4b..be6dfb99cb 100644 --- a/src/flash/core/finetuning.py +++ b/src/flash/core/finetuning.py @@ -221,6 +221,7 @@ class FlashDeepSpeedFinetuning(FlashBaseFinetuning): DeepSpeed cannot store and load its parameters when working with Lightning. So FlashDeepSpeedFinetuning overrides `_store` to not store its parameters. + """ def _store( diff --git a/src/flash/core/heads.py b/src/flash/core/heads.py index a4a5416633..160d9b2a2f 100644 --- a/src/flash/core/heads.py +++ b/src/flash/core/heads.py @@ -30,6 +30,7 @@ def _load_linear_head(num_features: int, num_classes: int) -> nn.Module: Returns: nn.Module: Linear head. + """ return nn.Linear(num_features, num_classes) diff --git a/src/flash/core/integrations/fiftyone/utils.py b/src/flash/core/integrations/fiftyone/utils.py index 6c12271115..8a14dd21a5 100644 --- a/src/flash/core/integrations/fiftyone/utils.py +++ b/src/flash/core/integrations/fiftyone/utils.py @@ -60,6 +60,7 @@ def visualize( Returns: a :class:`fiftyone:fiftyone.core.session.Session` + """ if flash._IS_TESTING: return None diff --git a/src/flash/core/integrations/pytorch_forecasting/adapter.py b/src/flash/core/integrations/pytorch_forecasting/adapter.py index a9fc69bf68..d84277c121 100644 --- a/src/flash/core/integrations/pytorch_forecasting/adapter.py +++ b/src/flash/core/integrations/pytorch_forecasting/adapter.py @@ -39,6 +39,7 @@ class PatchTimeSeriesDataSet(TimeSeriesDataSet): """Hack to prevent index construction or data validation / conversion when instantiating model. This enables the ``TimeSeriesDataSet`` to be created from a single row of data. + """ def _construct_index(self, data: DataFrame, predict_mode: bool) -> DataFrame: diff --git a/src/flash/core/model.py b/src/flash/core/model.py index df10337fd8..df1c92399c 100644 --- a/src/flash/core/model.py +++ b/src/flash/core/model.py @@ -330,6 +330,7 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, FineTuningHooks (val_metrics): ModuleDict() (test_metrics): ModuleDict() ) + """ optimizers_registry: FlashRegistry = _OPTIMIZERS_REGISTRY @@ -381,6 +382,7 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: Returns: A dict containing both the loss and relevant metrics + """ x, y = batch y_hat = self(x) diff --git a/src/flash/core/optimizers/lamb.py b/src/flash/core/optimizers/lamb.py index c2c7aa9db4..3485ebcbb8 100644 --- a/src/flash/core/optimizers/lamb.py +++ b/src/flash/core/optimizers/lamb.py @@ -105,6 +105,7 @@ def step(self, closure=None): Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. + """ loss = None if closure is not None: diff --git a/src/flash/core/optimizers/lars.py b/src/flash/core/optimizers/lars.py index 5cd288e53a..86e6058e14 100644 --- a/src/flash/core/optimizers/lars.py +++ b/src/flash/core/optimizers/lars.py @@ -73,6 +73,7 @@ class LARS(Optimizer): Parameters with weight decay set to 0 will automatically be excluded from layer-wise LR scaling. This is to ensure consistency with papers like SimCLR and BYOL. + """ def __init__( @@ -121,6 +122,7 @@ def step(self, closure=None): Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. + """ loss = None if closure is not None: diff --git a/src/flash/core/optimizers/lr_scheduler.py b/src/flash/core/optimizers/lr_scheduler.py index 924683cdd9..1ecb59a2f9 100644 --- a/src/flash/core/optimizers/lr_scheduler.py +++ b/src/flash/core/optimizers/lr_scheduler.py @@ -63,6 +63,7 @@ class LinearWarmupCosineAnnealingLR(_LRScheduler): ... scheduler.step(epoch) ... # train(...) ... # validate(...) + """ def __init__( diff --git a/src/flash/core/registry.py b/src/flash/core/registry.py index b968ee1934..924cce8219 100644 --- a/src/flash/core/registry.py +++ b/src/flash/core/registry.py @@ -158,6 +158,7 @@ def __call__( """This function is used to register new functions to the registry along their metadata. Functions can be filtered using metadata using the ``get`` function. + """ if providers is not None: metadata["providers"] = providers diff --git a/src/flash/core/serve/_compat/cached_property.py b/src/flash/core/serve/_compat/cached_property.py index 50327f8d3f..29b9592cb4 100644 --- a/src/flash/core/serve/_compat/cached_property.py +++ b/src/flash/core/serve/_compat/cached_property.py @@ -3,6 +3,7 @@ cached_property() - computed once per instance, cached as attribute credits: https://github.com/penguinolog/backports.cached_property + """ __all__ = ("cached_property",) diff --git a/src/flash/core/serve/component.py b/src/flash/core/serve/component.py index c267ea49b8..8ee20b540e 100644 --- a/src/flash/core/serve/component.py +++ b/src/flash/core/serve/component.py @@ -55,6 +55,7 @@ class to perform the analysis on ------ SyntaxError If parameters are not specified correctly. + """ params = inspect.signature(cls.__init__).parameters if len(params) > 3: @@ -89,6 +90,7 @@ def _validate_model_args( If an empty iterable is passed as the model argument TypeError If the args do not contain properly formatted model refences + """ if isiterable(args) and len(args) == 0: raise ValueError(f"Iterable args={args} must have length >= 1") @@ -122,6 +124,7 @@ def _validate_config_args(config: Optional[Dict[str, Union[str, int, float, byte If ``config`` is a dict with invalid key/values ValueError If ``config`` is a dict with 0 arguments + """ if config is None: return @@ -183,6 +186,7 @@ def __call__(cls, *args, **kwargs): super().__call__() within metaclass means: return instance created by calling metaclass __prepare__ -> __new__ -> __init__ + """ klass = super().__call__(*args, **kwargs) klass._flashserve_meta_ = replace(klass._flashserve_meta_) @@ -203,6 +207,7 @@ class ModelComponent(metaclass=FlashServeMeta): assets, etc. The specification must be YAML serializable and loadable to/from a fully initialized instance. It must contain the minimal set of information necessary to find and initialize its dependencies (assets) and itself. + """ _flashserve_meta_: Optional[Union[BoundMeta, UnboundMeta]] = None @@ -211,6 +216,7 @@ def __flashserve_init__(self, models, *, config=None): """Do a bunch of setup. instance's __flashserve_init__ calls subclass __init__ in turn. + """ _validate_model_args(models) _validate_config_args(config) diff --git a/src/flash/core/serve/composition.py b/src/flash/core/serve/composition.py index d627c02995..616464bb9a 100644 --- a/src/flash/core/serve/composition.py +++ b/src/flash/core/serve/composition.py @@ -63,6 +63,7 @@ class Composition(ServerMixin): which provides introspection of components, endpoints, etc. * We plan to add some user-facing API to the ``Composition`` object which allows for modification of the composition. + """ _uid_comps: Dict[str, ModelComponent] diff --git a/src/flash/core/serve/core.py b/src/flash/core/serve/core.py index c55be4641f..ce947b3122 100644 --- a/src/flash/core/serve/core.py +++ b/src/flash/core/serve/core.py @@ -32,6 +32,7 @@ class Endpoint: outputs The full name of a component output. Typically, specified by just passing in the component parameter attribute (i.e.``component.outputs.bar``). + """ route: str @@ -99,6 +100,7 @@ class Servable: ---- * How to handle ``__init__`` args for ``torch.nn.Module`` * How to handle ``__init__`` args not recorded in hparams of ``pl.LightningModule`` + """ @requires("serve") @@ -151,6 +153,7 @@ class Connection(NamedTuple): * This data structure should not be instantiated directly! The class_methods attached to the class are the indended mechanisms to create a new instance. + """ source_component: str @@ -191,6 +194,7 @@ class Parameter: Which component this type is associated with position Position in the while exposing it i.e `inputs` or `outputs` + """ name: str @@ -220,6 +224,7 @@ def __terminate_invalid_connection_request(self, other: "Parameter", dunder_meth TypeError, RuntimeError if the verification fails, we throw an exception to stop the connection from being created. + """ # assert this is actually a class object we can compare against. if not isinstance(other, self.__class__) or (other.__class__ != self.__class__): @@ -313,6 +318,7 @@ def make_parameter_container(data: Dict[str, Parameter]) -> ParameterContainer: * parameter name must be valid python attribute (identifier) and cannot be a builtin keyword. input names should have been validated by this point. + """ dataclass_fields = [(param_name, type(param)) for param_name, param in data.items()] ParameterContainer = make_dataclass( @@ -335,6 +341,7 @@ def make_param_dict( Tuple[Dict[str, Parameter], Dict[str, Parameter]] Element[0] == Input parameter dict Element[1] == Output parameter dict. + """ flashserve_inp_params, flashserve_out_params = {}, {} for inp_key, inp_dtype in inputs.items(): diff --git a/src/flash/core/serve/dag/optimization.py b/src/flash/core/serve/dag/optimization.py index 9ccb96709e..5b8de1196a 100644 --- a/src/flash/core/serve/dag/optimization.py +++ b/src/flash/core/serve/dag/optimization.py @@ -32,6 +32,7 @@ def cull(dsk, keys): dsk: culled graph dependencies: Dict mapping {key: [deps]}. Useful side effect to accelerate other optimizations, notably fuse. + """ if not isinstance(keys, (list, set)): keys = [keys] @@ -113,6 +114,7 @@ def fuse_linear(dsk, keys=None, dependencies=None, rename_keys=True): dsk: output graph with keys fused dependencies: dict mapping dependencies after fusion. Useful side effect to accelerate other downstream optimizations. + """ if keys is not None and not isinstance(keys, set): if not isinstance(keys, list): @@ -236,6 +238,7 @@ def inline(dsk, keys=None, inline_constants=True, dependencies=None): {'x': 1, 'y': (, 1), 'z': (, 1, (, 1))} >>> inline(d, keys='y', inline_constants=False) # doctest: +ELLIPSIS {'x': 1, 'y': (, 'x'), 'z': (, 'x', (, 'x'))} + """ if dependencies and isinstance(next(iter(dependencies.values())), list): dependencies = {k: set(v) for k, v in dependencies.items()} @@ -296,6 +299,7 @@ def inline_functions(dsk, output, fast_functions=None, inline_constants=False, d 'out': (, 'i', ( at ...>, 'y')), 'i': (, 'x'), 'x': 1} + """ if not fast_functions: return dsk @@ -338,6 +342,7 @@ def functions_of(task): >>> task = (add, (mul, 1, 2), (inc, 3)) >>> sorted(functions_of(task), key=str) # doctest: +ELLIPSIS [, , ] + """ funcs = set() @@ -875,6 +880,7 @@ class SubgraphCallable: A list of keys to be used as arguments to the callable. name : str, optional The name to use for the function. + """ __slots__ = ("dsk", "outkey", "inkeys", "name") diff --git a/src/flash/core/serve/dag/order.py b/src/flash/core/serve/dag/order.py index 9691bb619e..69447fad22 100644 --- a/src/flash/core/serve/dag/order.py +++ b/src/flash/core/serve/dag/order.py @@ -74,6 +74,7 @@ difference exists between two keys, use the key name to break ties. This relies on the regularity of graph constructors like dask.array to be a good proxy for ordering. This is usually a good idea and a sane default. + """ from collections import defaultdict @@ -110,6 +111,7 @@ def order(dsk, dependencies=None): >>> dsk = {'a': 1, 'b': 2, 'c': (inc, 'a'), 'd': (add, 'b', 'c')} >>> order(dsk) {'a': 0, 'c': 1, 'b': 2, 'd': 3} + """ if not dsk: return {} @@ -159,6 +161,7 @@ def dependents_key(x): """Choose a path from our starting task to our tactical goal. This path is connected to a large goal, but focuses on completing a small goal and being memory efficient. + """ return ( # Focus on being memory-efficient @@ -172,6 +175,7 @@ def dependencies_key(x): """Choose which dependency to run as part of a reverse DFS. This is very similar to both ``initial_stack_key``. + """ num_dependents = len(dependents[x]) ( @@ -629,6 +633,7 @@ def ndependencies(dependencies, dependents): ------- num_dependencies: Dict[key, int] total_dependencies: Dict[key, int] + """ num_needed = {} result = {} @@ -672,6 +677,7 @@ class StrComparable: >>> StrComparable('a') < StrComparable(1) False + """ __slots__ = ("obj",) diff --git a/src/flash/core/serve/dag/rewrite.py b/src/flash/core/serve/dag/rewrite.py index 2535554a5d..aebe058ba6 100644 --- a/src/flash/core/serve/dag/rewrite.py +++ b/src/flash/core/serve/dag/rewrite.py @@ -46,6 +46,7 @@ class Traverser: current The head of the current element in the traversal. This is simply `head` applied to the attribute `term`. + """ def __init__(self, term, stack=None): @@ -64,6 +65,7 @@ def copy(self): """Copy the traverser in its current state. This allows the traversal to be pushed onto a stack, for easy backtracking. + """ return Traverser(self.term, deque(self._stack)) @@ -92,6 +94,7 @@ class Token: """A token object. Used to express certain objects in the traversal of a task or pattern. + """ def __init__(self, name): @@ -172,6 +175,7 @@ class RewriteRule: ... else: ... return list, x >>> rule = RewriteRule(lhs, repl_list, variables) + """ def __init__(self, lhs, rhs, vars=()): @@ -234,6 +238,7 @@ class RuleSet: ---------- rules : list A list of `RewriteRule`s included in the `RuleSet`. + """ def __init__(self, *rules): @@ -255,6 +260,7 @@ def add(self, rule): Parameters ---------- rule : RewriteRule + """ if not isinstance(rule, RewriteRule): @@ -288,6 +294,7 @@ def iter_matches(self, term): Tuples of `(rule, subs)`, where `rule` is the rewrite rule being matched, and `subs` is a dictionary mapping the variables in the lhs of the rule to their matching values in the term. + """ S = Traverser(term) @@ -423,6 +430,7 @@ def _process_match(rule, syms): A dictionary of {vars : subterms} describing the substitution to make the pattern equivalent with the term. Returns `None` if the match is invalid. + """ subs = {} diff --git a/src/flash/core/serve/dag/task.py b/src/flash/core/serve/dag/task.py index 2e89272b88..59bb7875f8 100644 --- a/src/flash/core/serve/dag/task.py +++ b/src/flash/core/serve/dag/task.py @@ -19,6 +19,7 @@ def ishashable(x): True >>> ishashable([1]) False + """ try: hash(x) @@ -37,6 +38,7 @@ def istask(x): True >>> istask(1) False + """ return type(x) is tuple and x and callable(x[0]) @@ -79,6 +81,7 @@ def _execute_task(arg, cache): [[1, 2], [2, 1]] >>> _execute_task('foo', cache) # Passes through on non-keys 'foo' + """ if isinstance(arg, list): return [_execute_task(a, cache) for a in arg] @@ -120,6 +123,7 @@ def get(dsk: dict, out: Sequence[str], cache: dict = None, sortkeys: List[str] = 2 >>> get(d, 'y', sortkeys=['x', 'y']) 2 + """ for k in flatten(out) if isinstance(out, list) else [out]: if k not in dsk: @@ -161,6 +165,7 @@ def get_dependencies(dsk, key=None, task=no_default, as_list=False): {'x'} >>> get_dependencies(dsk, task=(inc, 'x')) # provide tasks directly {'x'} + """ if key is not None: arg = dsk[key] @@ -205,6 +210,7 @@ def get_deps(dsk): {'a': set(), 'b': {'a'}, 'c': {'b'}} >>> dict(dependents) {'a': {'b'}, 'b': {'c'}, 'c': set()} + """ dependencies = {k: get_dependencies(dsk, task=v) for k, v in dsk.items()} dependents = reverse_dict(dependencies) @@ -261,6 +267,7 @@ def subs(task, key, val): >>> from flash.core.serve.dag.utils_test import inc >>> subs((inc, 'x'), 'x', 1) # doctest: +ELLIPSIS (, 1) + """ type_task = type(task) if not (type_task is tuple and task and callable(task[0])): # istask(task): @@ -299,6 +306,7 @@ def _toposort(dsk, keys=None, returncycle=False, dependencies=None): """Stack-based depth-first search traversal. This is based on Tarjan's method for topological sorting (see wikipedia for pseudocode). + """ if keys is None: keys = dsk @@ -433,6 +441,7 @@ def quote(x): >>> from flash.core.serve.dag.utils_test import add >>> quote((add, 1, 2)) (literal,) + """ if istask(x) or type(x) is list or type(x) is dict: return (literal(x),) diff --git a/src/flash/core/serve/decorators.py b/src/flash/core/serve/decorators.py index 0675d037ee..858b2139c5 100644 --- a/src/flash/core/serve/decorators.py +++ b/src/flash/core/serve/decorators.py @@ -107,6 +107,7 @@ def _validate_expose_inputs_outputs_args(kwargs: Dict[str, BaseType]): >>> out = {'out': Number()} >>> _validate_expose_inputs_outputs_args(inp) >>> _validate_expose_inputs_outputs_args(out) + """ if not isinstance(kwargs, dict): raise TypeError(f"`expose` values must be {dict}. recieved {kwargs}") @@ -152,6 +153,7 @@ def expose(inputs: Dict[str, BaseType], outputs: Dict[str, BaseType]): TODO ---- * Examples in the docstring. + """ _validate_expose_inputs_outputs_args(inputs) _validate_expose_inputs_outputs_args(outputs) diff --git a/src/flash/core/serve/execution.py b/src/flash/core/serve/execution.py index 3b555660e0..330ac8242a 100644 --- a/src/flash/core/serve/execution.py +++ b/src/flash/core/serve/execution.py @@ -67,6 +67,7 @@ class TaskComposition: pre_optimization_dsk Merged component `_dsk` subgraphs (without payload / result mapping or connections applied.) + """ __slots__ = ( @@ -112,6 +113,7 @@ class UnprocessedTaskDask: map of ouput (results) key to output task key output_keys keys to get as results + """ __slots__ = ( @@ -150,6 +152,7 @@ def _process_initial( Returns ------- UnprocessedTaskDask + """ # mapping payload input keys -> serialized keys / tasks @@ -256,6 +259,7 @@ def build_composition( ``C_2_1 deserailize``from ``C_2`` / ``C_1``, we see here that since endpoints define the path through the DAG, we cannot eliminate them entirely either. + """ initial_task_dsk = _process_initial(endpoint_protocol, components) diff --git a/src/flash/core/serve/interfaces/models.py b/src/flash/core/serve/interfaces/models.py index 4d6c84b5b7..2177640b14 100644 --- a/src/flash/core/serve/interfaces/models.py +++ b/src/flash/core/serve/interfaces/models.py @@ -37,6 +37,7 @@ class EndpointProtocol: class initializer. Component inputs & outputs (as defined in `@expose` object decorations) dtype method (`serialize` and `deserialize`) type hints are inspected in order to constuct a specification unique to the endpoint, they are returned as subclasses of pydantic ``BaseModel``. + """ def __init__(self, name: str, endpoint: "Endpoint", components: Dict[str, "ModelComponent"]): diff --git a/src/flash/core/serve/server.py b/src/flash/core/serve/server.py index aeaf00c034..ed1af05222 100644 --- a/src/flash/core/serve/server.py +++ b/src/flash/core/serve/server.py @@ -20,6 +20,7 @@ class ServerMixin: debug If the server should be started up in debug mode. By default, False. testing If the server should return the ``app`` instance instead of blocking the process (via running the ``app`` in ``uvicorn``). This is used when taking advantage of a server ``TestClient``. By default, False + """ DEBUG: bool @@ -37,6 +38,7 @@ def serve(self, host: str = "127.0.0.1", port: int = 8000): host address to run the server on port port number to expose the running server on + """ if FLASH_DISABLE_SERVE: return None diff --git a/src/flash/core/serve/types/base.py b/src/flash/core/serve/types/base.py index 6ef42a8a2f..d530a9c42d 100644 --- a/src/flash/core/serve/types/base.py +++ b/src/flash/core/serve/types/base.py @@ -30,6 +30,7 @@ def deserialize(self, text: str, language: str): .. code-block:: python {"text": "some string", "language": "en"} + """ @cached_property @@ -54,6 +55,7 @@ def deserialize(self, *args, **kwargs): # pragma: no cover """Take the inputs from the network and deserialize/convert them. Output from this method will go to the exposed method as arguments. + """ raise NotImplementedError @@ -64,5 +66,6 @@ def packed_deserialize(self, kwargs): sophisticated datatypes (such as Repeated) where the developer wants to dictate how the unpacking happens. For simple cases like Image or Bbox etc., developer would never need to know the existence of this. Task graph would never call deserialize directly but always call this method. + """ return self.deserialize(**kwargs) diff --git a/src/flash/core/serve/types/bbox.py b/src/flash/core/serve/types/bbox.py index ba9a98184d..e85d77d8f1 100644 --- a/src/flash/core/serve/types/bbox.py +++ b/src/flash/core/serve/types/bbox.py @@ -17,6 +17,7 @@ class BBox(BaseType): like Javascript to use a dictionary with ``x1, y1, x2 and y2`` as keys, we went with DL convention which is to use a list/tuple in which four floats are arranged in the same ``order -> x1, y1, x2, y2`` + """ def __post_init__(self): diff --git a/src/flash/core/serve/types/image.py b/src/flash/core/serve/types/image.py index e20b94d884..8641916fda 100644 --- a/src/flash/core/serve/types/image.py +++ b/src/flash/core/serve/types/image.py @@ -40,6 +40,7 @@ class Image(BaseType): "I": 1, # (32-bit signed integer pixels) "F": 1, # (32-bit floating point pixels) } + """ height: Optional[int] = None diff --git a/src/flash/core/serve/types/label.py b/src/flash/core/serve/types/label.py index a5ad295016..cb1da78a2a 100644 --- a/src/flash/core/serve/types/label.py +++ b/src/flash/core/serve/types/label.py @@ -21,6 +21,7 @@ class Label(BaseType): classes A list, tuple or a dict of classes. If it's list or a tuple, index of the class, is the key. If it's a dictionary, the key must be an integer + """ path: Union[str, Path, None] = field(default=None) diff --git a/src/flash/core/serve/types/repeated.py b/src/flash/core/serve/types/repeated.py index 5efa86902b..63498d87fa 100644 --- a/src/flash/core/serve/types/repeated.py +++ b/src/flash/core/serve/types/repeated.py @@ -18,6 +18,7 @@ class Repeated(BaseType): Optional parameter specifying if there is a maximum length of the repeated elements (`int > 0`). If `max_len=None`, there can be any number of repeated elements. By default: `None`. + """ dtype: BaseType diff --git a/src/flash/core/serve/types/table.py b/src/flash/core/serve/types/table.py index 7fe1fb7a33..a073af02af 100644 --- a/src/flash/core/serve/types/table.py +++ b/src/flash/core/serve/types/table.py @@ -53,6 +53,7 @@ class Table(BaseType): * It might be better to remove pandas dependency to gain performance however we are offloading the validation logic to pandas which would have been painful if we were to do custom built logic + """ column_names: List[str] diff --git a/src/flash/core/serve/types/text.py b/src/flash/core/serve/types/text.py index dfeda9a59d..62586b1c96 100644 --- a/src/flash/core/serve/types/text.py +++ b/src/flash/core/serve/types/text.py @@ -22,6 +22,7 @@ class Text(BaseType): TODO: Allow other arguments such as language, max_len etc. Add guidelines to write custom tokenizer + """ tokenizer: Union[str, Any] diff --git a/src/flash/core/serve/utils.py b/src/flash/core/serve/utils.py index 472493e47c..67585d6105 100644 --- a/src/flash/core/serve/utils.py +++ b/src/flash/core/serve/utils.py @@ -9,6 +9,7 @@ def fn_outputs_to_keyed_map(serialize_fn_out_keys, fn_output) -> Dict[str, Any]: """Convert outputs of a function to a dict of `{result_name: values}` accepts function outputs which are sequence, dict, or object. + """ if len(serialize_fn_out_keys) == 1: if not isinstance(fn_output, dict): @@ -33,6 +34,7 @@ def download_file(url: str, *, download_path: Optional[Path] = None) -> str: ---- * cleanup on error * allow specific file names + """ fname = f"{url.split('/')[-1]}" fpath = str(download_path.absolute()) if download_path is not None else f"./{fname}" diff --git a/src/flash/core/trainer.py b/src/flash/core/trainer.py index dce8d1b011..fe7e39b83a 100644 --- a/src/flash/core/trainer.py +++ b/src/flash/core/trainer.py @@ -78,6 +78,7 @@ class Trainer(PlTrainer): >>> Trainer() # doctest: +ELLIPSIS <...trainer.Trainer object at ...> + """ @_defaults_from_env_vars @@ -184,6 +185,7 @@ def predict( Returns: Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. + """ # Note: Prediction on TPU device with multi cores is not supported yet if isinstance(self.accelerator, TPUAccelerator) and self.num_devices > 1: @@ -262,5 +264,6 @@ def configure_optimizers(self): optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches ) return [optimizer], [scheduler] + """ return super().estimated_stepping_batches diff --git a/src/flash/core/utilities/imports.py b/src/flash/core/utilities/imports.py index fd2ff6d361..cb118ddf86 100644 --- a/src/flash/core/utilities/imports.py +++ b/src/flash/core/utilities/imports.py @@ -186,6 +186,7 @@ def lazy_import(module_name, callback=None): Returns: a proxy module object that will be lazily imported when first used + """ return LazyModule(module_name, callback=callback) @@ -197,6 +198,7 @@ class LazyModule(types.ModuleType): module_name: the fully-qualified module name to import callback (None): a callback function to call before importing the module + """ def __init__(self, module_name, callback=None): diff --git a/src/flash/core/utilities/lightning_cli.py b/src/flash/core/utilities/lightning_cli.py index e288c964e4..898df68854 100644 --- a/src/flash/core/utilities/lightning_cli.py +++ b/src/flash/core/utilities/lightning_cli.py @@ -30,6 +30,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: For full details of accepted arguments see `ArgumentParser.__init__ `_. + """ super().__init__(*args, **kwargs) self.add_argument( @@ -56,6 +57,7 @@ def add_lightning_class_args( lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}. nested_key: Name of the nested namespace to store arguments. subclass_mode: Whether allow any subclass of the given class. + """ if callable(lightning_class) and not inspect.isclass(lightning_class): lightning_class = class_from_function(lightning_class) @@ -91,6 +93,7 @@ def add_optimizer_args( optimizer_class: Any subclass of torch.optim.Optimizer. nested_key: Name of the nested namespace to store arguments. link_to: Dot notation of a parser key to set arguments or AUTOMATIC. + """ if isinstance(optimizer_class, tuple): assert all(issubclass(o, Optimizer) for o in optimizer_class) @@ -119,6 +122,7 @@ def add_lr_scheduler_args( lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. nested_key: Name of the nested namespace to store arguments. link_to: Dot notation of a parser key to set arguments or AUTOMATIC. + """ if isinstance(lr_scheduler_class, tuple): assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class) @@ -141,6 +145,7 @@ class SaveConfigCallback(Callback): Raises: RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run + """ def __init__( @@ -309,6 +314,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: Args: parser: The argument parser object to which arguments can be added + """ def link_optimizers_and_lr_schedulers(self) -> None: @@ -359,6 +365,7 @@ def add_configure_optimizers_method_to_model(self) -> None: If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC', then a `configure_optimizers` method is automatically implemented in the model class. + """ def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]: @@ -453,6 +460,7 @@ def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) - Returns: The instantiated class object. + """ kwargs = init.get("init_args", {}) if not isinstance(args, tuple): diff --git a/src/flash/core/utilities/stages.py b/src/flash/core/utilities/stages.py index 5e6e653580..995bbd95e9 100644 --- a/src/flash/core/utilities/stages.py +++ b/src/flash/core/utilities/stages.py @@ -26,6 +26,7 @@ class RunningStage(LightningEnum): - ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING`` - ``TrainerFn.SERVING`` - ``RunningStage.SERVING`` - ``TrainerFn.TUNING`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}`` + """ TRAINING = "train" diff --git a/src/flash/image/classification/integrations/learn2learn.py b/src/flash/image/classification/integrations/learn2learn.py index 6a66fc08a6..4939135256 100644 --- a/src/flash/image/classification/integrations/learn2learn.py +++ b/src/flash/image/classification/integrations/learn2learn.py @@ -42,6 +42,7 @@ def __init__( epoch_length: The expected epoch length. This requires to be divisible by devices. devices: Number of devices being used. collate_fn: The collate_fn to be applied on multiple tasks + """ self.tasks = tasks self.epoch_length = epoch_length @@ -97,6 +98,7 @@ def __init__( num_workers: Number of workers to be provided to the DataLoader. epoch_length: The expected epoch length. This requires to be divisible by (num_workers * world_size). seed: The seed will be used on __iter__ call and should be the same for all processes. + """ self.taskset = taskset self.global_rank = global_rank diff --git a/src/flash/image/embedding/heads/vissl_heads.py b/src/flash/image/embedding/heads/vissl_heads.py index 16cbf16d96..4e69f0b258 100644 --- a/src/flash/image/embedding/heads/vissl_heads.py +++ b/src/flash/image/embedding/heads/vissl_heads.py @@ -41,6 +41,7 @@ class SimCLRHead(nn.Module): model_config: Model config AttrDict from VISSL dims: list of dimensions for creating a projection head use_bn: use batch-norm after each linear layer or not + """ def __init__( diff --git a/src/flash/image/embedding/strategies/default.py b/src/flash/image/embedding/strategies/default.py index 423a939128..f7e61f971d 100644 --- a/src/flash/image/embedding/strategies/default.py +++ b/src/flash/image/embedding/strategies/default.py @@ -74,6 +74,7 @@ def default(head: Optional[str] = None, loss_fn: Optional[str] = None, **kwargs) """Return `(None, None, [])` as loss function, head and hooks. Because default strategy only support prediction. + """ if head is not None: warnings.warn(f"default strategy has no heads. So given head({head}) is ignored.") diff --git a/src/flash/image/embedding/vissl/transforms/multicrop.py b/src/flash/image/embedding/vissl/transforms/multicrop.py index cddd57d9f7..357de4e639 100644 --- a/src/flash/image/embedding/vissl/transforms/multicrop.py +++ b/src/flash/image/embedding/vissl/transforms/multicrop.py @@ -44,6 +44,7 @@ class StandardMultiCropSSLTransform(InputTransform): gaussian_blur (bool): Specifies if the transforms' composition has Gaussian Blur jitter_strength (float): Specify the coefficient for color jitter transform normalize (Optional): Normalize transform from torchvision with params set according to the dataset + """ total_num_crops: int = 2 diff --git a/src/flash/image/embedding/vissl/transforms/utilities.py b/src/flash/image/embedding/vissl/transforms/utilities.py index e8658fecf4..915ee084a7 100644 --- a/src/flash/image/embedding/vissl/transforms/utilities.py +++ b/src/flash/image/embedding/vissl/transforms/utilities.py @@ -33,6 +33,7 @@ def multicrop_collate_fn(samples): """Multi-crop collate function for VISSL integration. Run custom collate on a single key since VISSL transforms affect only DataKeys.INPUT + """ result = vissl_collate_helper(samples) @@ -55,6 +56,7 @@ def simclr_collate_fn(samples): """Multi-crop collate function for VISSL integration. Run custom collate on a single key since VISSL transforms affect only DataKeys.INPUT + """ result = vissl_collate_helper(samples) diff --git a/src/flash/image/face_detection/input_transform.py b/src/flash/image/face_detection/input_transform.py index 9a889bd414..83d8c09747 100644 --- a/src/flash/image/face_detection/input_transform.py +++ b/src/flash/image/face_detection/input_transform.py @@ -33,6 +33,7 @@ def fastface_collate_fn(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence """Collate function from fastface. Organizes individual elements in a batch, calls prepare_batch from fastface and prepares the targets. + """ samples = {key: [sample[key] for sample in samples] for key in samples[0]} diff --git a/src/flash/pointcloud/detection/open3d_ml/app.py b/src/flash/pointcloud/detection/open3d_ml/app.py index 07f755f1bc..ea7c14db92 100644 --- a/src/flash/pointcloud/detection/open3d_ml/app.py +++ b/src/flash/pointcloud/detection/open3d_ml/app.py @@ -41,6 +41,7 @@ def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768 indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4]. width: The width of the visualization window. height: The height of the visualization window. + """ # Setup the labels lut = LabelLUT() diff --git a/src/flash/pointcloud/segmentation/open3d_ml/app.py b/src/flash/pointcloud/segmentation/open3d_ml/app.py index 7d3eab3f23..dc22623980 100644 --- a/src/flash/pointcloud/segmentation/open3d_ml/app.py +++ b/src/flash/pointcloud/segmentation/open3d_ml/app.py @@ -44,6 +44,7 @@ def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768 indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4]. width: The width of the visualization window. height: The height of the visualization window. + """ # Setup the labels lut = LabelLUT() diff --git a/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py b/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py index 6f7a4fcc53..a55acb196a 100644 --- a/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py +++ b/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py @@ -134,6 +134,7 @@ def get_label_to_names(self): Returns: A dict where keys are label numbers and values are the corresponding names. + """ return self.meta["label_to_names"] diff --git a/src/flash/template/classification/data.py b/src/flash/template/classification/data.py index ce0318f758..9b6ebdeba9 100644 --- a/src/flash/template/classification/data.py +++ b/src/flash/template/classification/data.py @@ -90,6 +90,7 @@ def predict_load_data(self, data: Bunch) -> Sequence[Dict[str, Any]]: Returns: A sequence of samples / sample metadata. + """ return super().load_data(data.data) diff --git a/src/flash/text/seq2seq/core/model.py b/src/flash/text/seq2seq/core/model.py index 8a00672c78..e988258f69 100644 --- a/src/flash/text/seq2seq/core/model.py +++ b/src/flash/text/seq2seq/core/model.py @@ -83,6 +83,7 @@ class Seq2SeqTask(Task): learning_rate: Learning rate to use for training, defaults to `3e-4` num_beams: Number of beams to use in validation when generating predictions. Defaults to `4` enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training + """ required_extras: str = "text" diff --git a/src/flash/video/classification/model.py b/src/flash/video/classification/model.py index 578312cad7..f324474e0d 100644 --- a/src/flash/video/classification/model.py +++ b/src/flash/video/classification/model.py @@ -62,6 +62,7 @@ class VideoClassifier(ClassificationTask): head: either a `nn.Module` or a callable function that converts the features extrated from the backbone into class log probabilities (assuming default loss function). If `None`, will default to using a single linear layer. + """ backbones: FlashRegistry = _VIDEO_CLASSIFIER_BACKBONES diff --git a/src/flash/video/classification/utils.py b/src/flash/video/classification/utils.py index 1c8cd2526e..90951ee06b 100644 --- a/src/flash/video/classification/utils.py +++ b/src/flash/video/classification/utils.py @@ -50,6 +50,7 @@ def __next__(self) -> dict: 'video_label': 'video_index': , } + """ if not self._video_sampler_iter: # Setup MultiProcessSampler here - after PyTorch DataLoader workers are spawned. diff --git a/tests/core/serve/test_dag/test_order.py b/tests/core/serve/test_dag/test_order.py index 220a4593fd..b22c81a8c5 100644 --- a/tests/core/serve/test_dag/test_order.py +++ b/tests/core/serve/test_dag/test_order.py @@ -47,6 +47,7 @@ def test_avoid_broker_nodes(abcde): a0 a1 a0 should be run before a1 + """ a, b, c, d, e = abcde dsk = { @@ -100,6 +101,7 @@ def test_base_of_reduce_preferred(abcde): c We really want to run b0 quickly + """ a, b, c, d, e = abcde dsk = {(a, i): (f, (a, i - 1), (b, i)) for i in [1, 2, 3]} @@ -201,6 +203,7 @@ def test_deep_bases_win_over_dependents(abcde): b c | / \ | / e d + """ a, b, c, d, e = abcde dsk = {a: (f, b, c, d), b: (f, d, e), c: (f, d), d: 1, e: 2} @@ -297,6 +300,7 @@ def test_run_smaller_sections(abcde): a c e cc Prefer to run acb first because then we can get that out of the way + """ a, b, c, d, e = abcde aa, bb, cc, dd = (x * 2 for x in [a, b, c, d]) @@ -389,6 +393,7 @@ def test_nearest_neighbor(abcde): Want to finish off a local group before moving on. This is difficult because all groups are connected. + """ a, b, c, _, _ = abcde a1, a2, a3, a4, a5, a6, a7, a8, a9 = (a + i for i in "123456789") @@ -527,6 +532,7 @@ def test_map_overlap(abcde): e1 e2 e5 Want to finish b1 before we start on e5 + """ a, b, c, d, e = abcde dsk = { @@ -697,6 +703,7 @@ def test_switching_dependents(abcde): This test is pretty specific to how `order` is implemented and is intended to increase code coverage. + """ a, b, c, d, e = abcde dsk = { diff --git a/tests/core/serve/test_gridbase_validations.py b/tests/core/serve/test_gridbase_validations.py index 595cfc4528..6ae23f7a80 100644 --- a/tests/core/serve/test_gridbase_validations.py +++ b/tests/core/serve/test_gridbase_validations.py @@ -175,6 +175,7 @@ def test_ModelComponent_raises_if_exposed_input_keys_differ_from_decorated_metho This is noted because it differs from some other metaclass validations which will raise an exception at class definition time. + """ from tests.core.serve.models import ClassificationInference @@ -198,6 +199,7 @@ def test_ModelComponent_raises_if_config_is_empty_dict(lightning_squeezenet1_1_o This is noted because it differs from some other metaclass validations which will raise an exception at class definition time. + """ class ConfigComponent(ModelComponent): @@ -218,6 +220,7 @@ def test_ModelComponent_raises_if_model_is_empty_iterable(): This is noted because it differs from some other metaclass validations which will raise an exception at class definition time. + """ class ConfigComponent(ModelComponent): diff --git a/tests/core/utilities/test_embedder.py b/tests/core/utilities/test_embedder.py index 082d7068cc..20f6792b29 100644 --- a/tests/core/utilities/test_embedder.py +++ b/tests/core/utilities/test_embedder.py @@ -63,6 +63,7 @@ def test_embedder_scaling_overhead(): 200 layer model. Note that this bound is intentionally high in an effort to reduce the flakiness of the test. + """ shallow_embedder = Embedder(NLayerModel(3), "backbone.2") @@ -91,6 +92,7 @@ def test_embedder_raising_overhead(): execute the model without the embedder. Note that this bound is intentionally high in an effort to reduce the flakiness of the test. + """ model = NLayerModel(10) embedder = Embedder(model, "output") diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index 96505cde7a..14e5f51d9b 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -33,6 +33,7 @@ def training_step(...): model = BaseTestModel() model.training_epoch_end = None + """ super().__init__() self.layer = torch.nn.Linear(32, 2) diff --git a/tests/video/classification/test_model.py b/tests/video/classification/test_model.py index 8092d5e7be..fc60b075ce 100644 --- a/tests/video/classification/test_model.py +++ b/tests/video/classification/test_model.py @@ -75,6 +75,7 @@ def temp_encoded_video(num_frames: int, fps: int, height=10, width=10, prefix=No """Creates a temporary lossless, mp4 video with synthetic content. Uses a context which deletes the video after exit. + """ # Lossless options. video_codec = "libx264rgb" @@ -93,6 +94,7 @@ def mock_video_data_frame(): Returns a labeled video file which points to this mock encoded video dataset, the ordered label and videos tuples and the video duration in seconds. + """ num_frames = 10 fps = 5 @@ -127,6 +129,7 @@ def mock_encoded_video_dataset_folder(tmpdir): """Creates a temporary mock encoded video directory tree with 2 videos labeled 1, 2. Returns a directory that to this mock encoded video dataset and the video duration in seconds. + """ num_frames = 10 fps = 5