diff --git a/valohai/config.py b/valohai/config.py
index 42709c7..d5e8117 100644
--- a/valohai/config.py
+++ b/valohai/config.py
@@ -1,5 +1,5 @@
 import os
 
 
-def is_running_in_valohai():
+def is_running_in_valohai() -> bool:
     return bool(os.environ.get("VH_JOB_ID"))
diff --git a/valohai/inputs.py b/valohai/inputs.py
index 9e9ac62..65a4f92 100644
--- a/valohai/inputs.py
+++ b/valohai/inputs.py
@@ -7,7 +7,7 @@
 
 
 class Input:
-    def __init__(self, name: str):
+    def __init__(self, name: str) -> None:
         self.name = str(name)
 
     def paths(
diff --git a/valohai/internals/compression.py b/valohai/internals/compression.py
index 37c1e9e..a643907 100644
--- a/valohai/internals/compression.py
+++ b/valohai/internals/compression.py
@@ -60,12 +60,18 @@ def put(self, archive_name, source: Union[str, IO]):
 
 
 class ZipArchive(BaseArchive, zipfile.ZipFile):
-    def __init__(self, file, mode="r", *, compresslevel=1):
+    def __init__(self, file: str, mode: str = "r", *, compresslevel: int = 1) -> None:
         # Only Python 3.7+ has the compresslevel kwarg here
         super().__init__(file, mode, compression=zipfile.ZIP_STORED)
         self.compresslevel = compresslevel
 
-    def writestream(self, arcname, data, compress_type, compresslevel):
+    def writestream(
+        self,
+        arcname: str,
+        data: Union[str, bytes, IO[bytes]],
+        compress_type: int,
+        compresslevel: int,
+    ) -> None:
         # Like `writestr`, but also supports a stream (and doesn't support directories).
         zinfo = zipfile.ZipInfo(filename=arcname)
         zinfo.compress_type = compress_type
@@ -82,7 +88,7 @@ def writestream(self, arcname, data, compress_type, compresslevel):
                 shutil.copyfileobj(data, dest, 524288)
         assert zinfo.file_size
 
-    def put(self, archive_name, source: Union[str, IO]):
+    def put(self, archive_name: str, source: Union[str, IO]) -> None:
         compress_type = (
             zipfile.ZIP_DEFLATED
             if guess_compressible(archive_name)
@@ -103,7 +109,7 @@ def put(self, archive_name, source: Union[str, IO]):
 
 
 class TarArchive(BaseArchive, tarfile.TarFile):
-    def put(self, archive_name, source: Union[str, IO]):
+    def put(self, archive_name: str, source: Union[str, IO]) -> None:
         with contextlib.ExitStack() as es:
             if isinstance(source, str):
                 size = os.stat(source).st_size
@@ -119,7 +125,7 @@ def put(self, archive_name, source: Union[str, IO]):
             self.addfile(tarinfo, stream)
 
 
-def open_archive(path: str):
+def open_archive(path: str) -> BaseArchive:
     if path.endswith(".zip"):
         return ZipArchive(path, "w")
     elif path.endswith(".tar"):
diff --git a/valohai/internals/download.py b/valohai/internals/download.py
index 9502a2c..f1f5937 100644
--- a/valohai/internals/download.py
+++ b/valohai/internals/download.py
@@ -3,7 +3,7 @@
 # TODO: This is close to valohai-local-run. Possibility to merge.
 
 
-def download_url(url, path, force_download=False):
+def download_url(url: str, path: str, force_download: bool = False) -> str:
     if not os.path.isfile(path) or force_download:
         try:
             import requests
diff --git a/valohai/internals/files.py b/valohai/internals/files.py
index 7048f86..d07183f 100644
--- a/valohai/internals/files.py
+++ b/valohai/internals/files.py
@@ -5,7 +5,7 @@
 from typing import Set, Union
 
 
-def set_file_read_only(path: str):
+def set_file_read_only(path: str) -> None:
     os.chmod(path, S_IREAD | S_IRGRP | S_IROTH)
 
 
diff --git a/valohai/internals/guid.py b/valohai/internals/guid.py
index 2938ea7..21d6b59 100644
--- a/valohai/internals/guid.py
+++ b/valohai/internals/guid.py
@@ -4,7 +4,7 @@
 _execution_guid = None
 
 
-def get_execution_guid():
+def get_execution_guid() -> str:
     global _execution_guid
     if not _execution_guid:
         _execution_guid = f'{time.strftime("%Y%m%d-%H%M%S")}-{uuid.uuid4().hex[:6]}'
diff --git a/valohai/internals/input_info.py b/valohai/internals/input_info.py
index 7086b4e..e5cf35d 100644
--- a/valohai/internals/input_info.py
+++ b/valohai/internals/input_info.py
@@ -10,17 +10,25 @@
 
 
 class FileInfo:
-    def __init__(self, *, name, uri, path, size, checksums):
+    def __init__(
+        self,
+        *,
+        name: str,
+        uri: Optional[str],
+        path: Optional[str],
+        size: Optional[int],
+        checksums: Optional[dict],
+    ) -> None:
         self.name = str(name)
         self.uri = str(uri) if uri else None
         self.checksums = checksums
         self.path = str(path) if path else None
         self.size = int(size) if size else None
 
-    def is_downloaded(self):
+    def is_downloaded(self) -> Optional[bool]:
         return self.path and os.path.isfile(self.path)
 
-    def download(self, path, force_download: bool = False):
+    def download(self, path: str, force_download: bool = False) -> None:
         self.path = download_url(
             self.uri, os.path.join(path, self.name), force_download
         )
@@ -31,13 +39,13 @@ class InputInfo:
     def __init__(self, files: Iterable[FileInfo]):
         self.files = list(files)
 
-    def is_downloaded(self):
+    def is_downloaded(self) -> bool:
         if not self.files:
             return False
 
         return all(f.is_downloaded() for f in self.files)
 
-    def download(self, path, force_download: bool = False):
+    def download(self, path: str, force_download: bool = False) -> None:
         for f in self.files:
             f.download(path, force_download)
 
diff --git a/valohai/internals/merge.py b/valohai/internals/merge.py
index 7259190..d1ddddf 100644
--- a/valohai/internals/merge.py
+++ b/valohai/internals/merge.py
@@ -1,7 +1,8 @@
 import copy
 
-from valohai_yaml.objs import Config, Step
 from valohai_yaml.objs.base import Item
+from valohai_yaml.objs.config import Config
+from valohai_yaml.objs.step import Step
 from valohai_yaml.utils.merge import merge_dicts, merge_simple
 
 from valohai.consts import DEFAULT_DOCKER_IMAGE
diff --git a/valohai/internals/parsing.py b/valohai/internals/parsing.py
index 0d44d63..35e6542 100644
--- a/valohai/internals/parsing.py
+++ b/valohai/internals/parsing.py
@@ -3,7 +3,7 @@
 from typing import List
 
 
-def is_module_function_call(node, module, function):
+def is_module_function_call(node: ast.Call, module: str, function: str) -> bool:
     try:
         return node.func.attr == function and node.func.value.id == module
     except AttributeError:
@@ -39,25 +39,25 @@ class PrepareParser(ast.NodeVisitor):
     valohai.prepare(parameters=get_parameters())
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.assignments = {}
         self.parameters = {}
         self.inputs = {}
         self.step = None
         self.image = None
 
-    def visit_Assign(self, node):
+    def visit_Assign(self, node: ast.Assign) -> None:
         try:
             self.assignments[node.targets[0].id] = ast.literal_eval(node.value)
         except ValueError:
             # We don't care about assignments that can't be literal_eval():ed
             pass
 
-    def visit_Call(self, node):
+    def visit_Call(self, node: ast.Call) -> None:
         if is_module_function_call(node, "valohai", "prepare"):
             self.process_valohai_prepare_call(node)
 
-    def process_valohai_prepare_call(self, node):
+    def process_valohai_prepare_call(self, node: ast.Call) -> None:
         self.step = "default"
         if hasattr(node, "keywords"):
             for key in node.keywords:
@@ -70,7 +70,7 @@ def process_valohai_prepare_call(self, node):
                 elif key.arg == "image":
                     self.image = ast.literal_eval(key.value)
 
-    def process_default_inputs_arg(self, key):
+    def process_default_inputs_arg(self, key: ast.keyword) -> None:
         if isinstance(key.value, ast.Name) and key.value.id in self.assignments:
             self.inputs = {
                 key: value if isinstance(value, List) else [value]
@@ -84,7 +84,7 @@ def process_default_inputs_arg(self, key):
         else:
             raise NotImplementedError()
 
-    def process_default_parameters_arg(self, key):
+    def process_default_parameters_arg(self, key: ast.keyword) -> None:
         if isinstance(key.value, ast.Name) and key.value.id in self.assignments:
             self.parameters = self.assignments[key.value.id]
         elif isinstance(key.value, ast.Dict):
diff --git a/valohai/internals/vfs.py b/valohai/internals/vfs.py
index 0ee0951..3b30ad7 100644
--- a/valohai/internals/vfs.py
+++ b/valohai/internals/vfs.py
@@ -3,9 +3,9 @@
 import shutil
 import tempfile
 from contextlib import ExitStack
-from tarfile import TarFile, TarInfo
+from tarfile import ExFileObject, TarFile, TarInfo
 from typing import IO, Optional, Union
-from zipfile import ZipFile, ZipInfo
+from zipfile import ZipExtFile, ZipFile, ZipInfo
 
 
 class File:
@@ -14,7 +14,7 @@ class File:
     def open(self) -> io.BufferedReader:
         raise NotImplementedError("...")
 
-    def read(self):
+    def read(self) -> bytes:
         with self.open() as f:
             return f.read()
 
@@ -35,12 +35,14 @@ class FileOnDisk(File):
     path: str
     dir_entry: Optional[os.DirEntry]
 
-    def __init__(self, name, path, dir_entry=None):
+    def __init__(
+        self, name: str, path: str, dir_entry: Optional[os.DirEntry] = None
+    ) -> None:
         self._name = name
         self.path = path
         self.dir_entry = dir_entry
 
-    def open(self):
+    def open(self) -> io.BufferedReader:
         return open(self.path, "rb")  # noqa: SIM115
 
     @property
@@ -51,7 +53,7 @@ def name(self) -> str:
 class FileInContainer(File):
     _concrete_path: Optional[str] = None
 
-    def open_concrete(self, delete=True):
+    def open_concrete(self, delete: bool = True) -> IO[bytes]:
         if self._concrete_path and os.path.isfile(self._concrete_path):
             return open(self._concrete_path)  # noqa: SIM115
         tf = tempfile.NamedTemporaryFile(suffix=self.extension, delete=delete)
@@ -60,7 +62,7 @@ def open_concrete(self, delete=True):
         self._concrete_path = tf.name
         return tf
 
-    def extract(self, destination: Union[str, IO]):
+    def extract(self, destination: Union[str, IO]) -> None:
         if isinstance(destination, str):
             destination = open(destination, "wb")  # noqa: SIM115
             should_close = True
@@ -72,7 +74,7 @@ def extract(self, destination: Union[str, IO]):
             if should_close:
                 destination.close()
 
-    def _do_extract(self, destination: IO):
+    def _do_extract(self, destination: IO) -> None:
         # if a file has a better idea how to write itself into an IO, this is the place
         with self.open() as f:
             shutil.copyfileobj(f, destination)
@@ -82,12 +84,14 @@ class FileInZip(FileInContainer):
     zipfile: ZipFile
     zipinfo: ZipInfo
 
-    def __init__(self, parent_file, zipfile, zipinfo):
+    def __init__(
+        self, parent_file: FileOnDisk, zipfile: ZipFile, zipinfo: ZipInfo
+    ) -> None:
         self.parent_file = parent_file
         self.zipfile = zipfile
         self.zipinfo = zipinfo
 
-    def open(self):
+    def open(self) -> ZipExtFile:
         return self.zipfile.open(self.zipinfo, "r")
 
     @property
@@ -107,12 +111,14 @@ class FileInTar(FileInContainer):
     tarfile: TarFile
     tarinfo: TarInfo
 
-    def __init__(self, parent_file, tarfile, tarinfo):
+    def __init__(
+        self, parent_file: FileOnDisk, tarfile: TarFile, tarinfo: TarInfo
+    ) -> None:
         self.parent_file = parent_file
         self.tarfile = tarfile
         self.tarinfo = tarinfo
 
-    def open(self):
+    def open(self) -> ExFileObject:
         return self.tarfile.extractfile(self.tarinfo)
 
     @property
@@ -144,14 +150,14 @@ def find_files_in_tar(vr: "VFS", df: FileOnDisk) -> None:
 
 
 class VFS:
-    def __init__(self):
+    def __init__(self) -> None:
         self.files = []
         self.exit_stack = ExitStack()
 
-    def __enter__(self):
+    def __enter__(self) -> "VFS":
         return self
 
-    def __exit__(self, *exc_details):
+    def __exit__(self, *exc_details) -> None:
         self.exit_stack.__exit__(*exc_details)
 
 
@@ -160,8 +166,8 @@ def add_disk_file(
     name: str,
     path: str,
     dir_entry: Optional[os.DirEntry] = None,
-    process_archives=False,
-):
+    process_archives: bool = False,
+) -> None:
     disk_file = FileOnDisk(name=name, path=path, dir_entry=dir_entry)
     if process_archives:
         extension = disk_file.extension.lower()
@@ -174,7 +180,7 @@ def add_disk_file(
     vfs.files.append(disk_file)
 
 
-def find_files(vfs: VFS, root: str, *, process_archives: bool):
+def find_files(vfs: VFS, root: str, *, process_archives: bool) -> None:
     dent: os.DirEntry
 
     def _walk(path):
diff --git a/valohai/internals/yaml.py b/valohai/internals/yaml.py
index ef9f318..0164e4c 100644
--- a/valohai/internals/yaml.py
+++ b/valohai/internals/yaml.py
@@ -1,8 +1,10 @@
 import os
 from typing import Any, Dict
 
-from valohai_yaml.objs import Config, Parameter, Step
+from valohai_yaml.objs.config import Config
 from valohai_yaml.objs.input import Input, KeepDirectories
+from valohai_yaml.objs.parameter import Parameter
+from valohai_yaml.objs.step import Step
 
 from valohai.consts import DEFAULT_DOCKER_IMAGE
 from valohai.internals.parsing import parse
@@ -17,7 +19,7 @@ def generate_step(
     step: str,
     image: str,
     parameters: ParameterDict,
-    inputs: InputDict
+    inputs: InputDict,
 ) -> Step:
     config_step = Step(
         name=step,
@@ -50,7 +52,7 @@ def generate_config(
     step: str,
     image: str,
     parameters: ParameterDict,
-    inputs: InputDict
+    inputs: InputDict,
 ) -> Config:
     step = generate_step(
         relative_source_path=relative_source_path,
@@ -85,7 +87,7 @@ def get_source_relative_path(source_path: str, config_path: str) -> str:
     return os.path.join(relative_source_dir, os.path.basename(source_path))
 
 
-def parse_config_from_source(source_path: str, config_path: str):
+def parse_config_from_source(source_path: str, config_path: str) -> Config:
     with open(source_path) as source_file:
         parsed = parse(source_file.read())
     if not parsed.step:
diff --git a/valohai/metadata.py b/valohai/metadata.py
index db35944..8143010 100644
--- a/valohai/metadata.py
+++ b/valohai/metadata.py
@@ -1,20 +1,21 @@
 import json
+from typing import Any
 
 _supported_types = [int, float]
 
 
 class Logger:
-    def __init__(self):
+    def __init__(self) -> None:
         self.partial_logs = {}
 
-    def __enter__(self):
+    def __enter__(self) -> "Logger":
         self.partial_logs = {}
         return self
 
-    def __exit__(self, type, value, traceback):
+    def __exit__(self, type, value, traceback) -> None:
         self.flush()
 
-    def log(self, *args, **kwargs):
+    def log(self, *args, **kwargs) -> None:
         """Log a single name/value pair to be flushed into standard output later as batch.
 
         For a repeating iteration like a machine learning training loop, Valohai expects
@@ -54,7 +55,7 @@ def log(self, *args, **kwargs):
         for key, value in kwargs.items():
             self._serialize(key, value)
 
-    def flush(self):
+    def flush(self) -> None:
         """Flush all the partial logs into standard as a batch.
 
         For a repeating iteration like a machine learning training loop, Valohai expects
@@ -74,7 +75,7 @@ def flush(self):
         print(f"\n{json.dumps(self.partial_logs, default=str)}")  # noqa
         self.partial_logs.clear()
 
-    def _serialize(self, name, value):
+    def _serialize(self, name: str, value: Any) -> None:
         self.partial_logs.update({str(name): value})
 
 
diff --git a/valohai/outputs.py b/valohai/outputs.py
index 9ec13dd..f662447 100644
--- a/valohai/outputs.py
+++ b/valohai/outputs.py
@@ -14,7 +14,7 @@
 
 
 class Output:
-    def __init__(self, name: str = ""):
+    def __init__(self, name: str = "") -> None:
         self.name = str(name)
 
     def path(self, filename: str, makedirs: bool = True) -> str:
@@ -52,7 +52,7 @@ def path(self, filename: str, makedirs: bool = True) -> str:
 
         return path
 
-    def live_upload(self, filename: str):
+    def live_upload(self, filename: str) -> None:
         for file_path in glob.glob(get_glob_pattern(self.path(filename))):
             set_file_read_only(file_path)
 
diff --git a/valohai/parameters.py b/valohai/parameters.py
index 690e199..7c1be5b 100644
--- a/valohai/parameters.py
+++ b/valohai/parameters.py
@@ -1,10 +1,9 @@
 from valohai.internals import global_state
-
-from .internals.parameters import load_parameter, supported_types
+from valohai.internals.parameters import load_parameter, supported_types
 
 
 class Parameter:
-    def __init__(self, name: str, default: supported_types = None):
+    def __init__(self, name: str, default: supported_types = None) -> None:
         self.name = str(name)
         self.default = default
 
diff --git a/valohai/utils.py b/valohai/utils.py
index 68fe88d..22b8971 100644
--- a/valohai/utils.py
+++ b/valohai/utils.py
@@ -16,7 +16,7 @@ def prepare(
     default_parameters: Optional[dict] = None,
     default_inputs: Optional[dict] = None,
     image: str = None,
-):
+) -> None:
     """Define the name of the step and it's required inputs, parameters and Docker image
 
     Has dual purpose:
diff --git a/valohai/yaml.py b/valohai/yaml.py
index 9382c33..e6a8ef8 100644
--- a/valohai/yaml.py
+++ b/valohai/yaml.py
@@ -1,7 +1,7 @@
 from collections import OrderedDict
 
 import yaml
-from valohai_yaml.objs import Config
+from valohai_yaml.objs.config import Config
 
 # https://stackoverflow.com/questions/42518067/how-to-use-ordereddict-as-an-input-in-yaml-dump-or-yaml-safe-dump
 yaml.add_representer(
@@ -12,7 +12,7 @@
 )
 
 
-def config_to_yaml(config: Config):
+def config_to_yaml(config: Config) -> str:
     """Serialize Valohai Config to YAML
 
     :param config: valohai_yaml.objs.Config object