Skip to content

Commit

Permalink
Add type hints using Monkeytype
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Apr 19, 2021
1 parent 0222f63 commit 9bf5b92
Show file tree
Hide file tree
Showing 16 changed files with 82 additions and 59 deletions.
2 changes: 1 addition & 1 deletion valohai/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os


def is_running_in_valohai():
def is_running_in_valohai() -> bool:
return bool(os.environ.get("VH_JOB_ID"))
2 changes: 1 addition & 1 deletion valohai/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class Input:
def __init__(self, name: str):
def __init__(self, name: str) -> None:
self.name = str(name)

def paths(
Expand Down
16 changes: 11 additions & 5 deletions valohai/internals/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,18 @@ def put(self, archive_name, source: Union[str, IO]):


class ZipArchive(BaseArchive, zipfile.ZipFile):
def __init__(self, file, mode="r", *, compresslevel=1):
def __init__(self, file: str, mode: str = "r", *, compresslevel: int = 1) -> None:
# Only Python 3.7+ has the compresslevel kwarg here
super().__init__(file, mode, compression=zipfile.ZIP_STORED)
self.compresslevel = compresslevel

def writestream(self, arcname, data, compress_type, compresslevel):
def writestream(
self,
arcname: str,
data: Union[str, bytes, IO[bytes]],
compress_type: int,
compresslevel: int,
) -> None:
# Like `writestr`, but also supports a stream (and doesn't support directories).
zinfo = zipfile.ZipInfo(filename=arcname)
zinfo.compress_type = compress_type
Expand All @@ -82,7 +88,7 @@ def writestream(self, arcname, data, compress_type, compresslevel):
shutil.copyfileobj(data, dest, 524288)
assert zinfo.file_size

def put(self, archive_name, source: Union[str, IO]):
def put(self, archive_name: str, source: Union[str, IO]) -> None:
compress_type = (
zipfile.ZIP_DEFLATED
if guess_compressible(archive_name)
Expand All @@ -103,7 +109,7 @@ def put(self, archive_name, source: Union[str, IO]):


class TarArchive(BaseArchive, tarfile.TarFile):
def put(self, archive_name, source: Union[str, IO]):
def put(self, archive_name: str, source: Union[str, IO]) -> None:
with contextlib.ExitStack() as es:
if isinstance(source, str):
size = os.stat(source).st_size
Expand All @@ -119,7 +125,7 @@ def put(self, archive_name, source: Union[str, IO]):
self.addfile(tarinfo, stream)


def open_archive(path: str):
def open_archive(path: str) -> BaseArchive:
if path.endswith(".zip"):
return ZipArchive(path, "w")
elif path.endswith(".tar"):
Expand Down
2 changes: 1 addition & 1 deletion valohai/internals/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


# TODO: This is close to valohai-local-run. Possibility to merge.
def download_url(url, path, force_download=False):
def download_url(url: str, path: str, force_download: bool = False) -> str:
if not os.path.isfile(path) or force_download:
try:
import requests
Expand Down
2 changes: 1 addition & 1 deletion valohai/internals/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Set, Union


def set_file_read_only(path: str):
def set_file_read_only(path: str) -> None:
os.chmod(path, S_IREAD | S_IRGRP | S_IROTH)


Expand Down
2 changes: 1 addition & 1 deletion valohai/internals/guid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
_execution_guid = None


def get_execution_guid():
def get_execution_guid() -> str:
global _execution_guid
if not _execution_guid:
_execution_guid = f'{time.strftime("%Y%m%d-%H%M%S")}-{uuid.uuid4().hex[:6]}'
Expand Down
18 changes: 13 additions & 5 deletions valohai/internals/input_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,25 @@


class FileInfo:
def __init__(self, *, name, uri, path, size, checksums):
def __init__(
self,
*,
name: str,
uri: Optional[str],
path: Optional[str],
size: Optional[int],
checksums: Optional[dict],
) -> None:
self.name = str(name)
self.uri = str(uri) if uri else None
self.checksums = checksums
self.path = str(path) if path else None
self.size = int(size) if size else None

def is_downloaded(self):
def is_downloaded(self) -> Optional[bool]:
return self.path and os.path.isfile(self.path)

def download(self, path, force_download: bool = False):
def download(self, path: str, force_download: bool = False) -> None:
self.path = download_url(
self.uri, os.path.join(path, self.name), force_download
)
Expand All @@ -31,13 +39,13 @@ class InputInfo:
def __init__(self, files: Iterable[FileInfo]):
self.files = list(files)

def is_downloaded(self):
def is_downloaded(self) -> bool:
if not self.files:
return False

return all(f.is_downloaded() for f in self.files)

def download(self, path, force_download: bool = False):
def download(self, path: str, force_download: bool = False) -> None:
for f in self.files:
f.download(path, force_download)

Expand Down
3 changes: 2 additions & 1 deletion valohai/internals/merge.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import copy

from valohai_yaml.objs import Config, Step
from valohai_yaml.objs.base import Item
from valohai_yaml.objs.config import Config
from valohai_yaml.objs.step import Step
from valohai_yaml.utils.merge import merge_dicts, merge_simple

from valohai.consts import DEFAULT_DOCKER_IMAGE
Expand Down
14 changes: 7 additions & 7 deletions valohai/internals/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List


def is_module_function_call(node, module, function):
def is_module_function_call(node: ast.Call, module: str, function: str) -> bool:
try:
return node.func.attr == function and node.func.value.id == module
except AttributeError:
Expand Down Expand Up @@ -39,25 +39,25 @@ class PrepareParser(ast.NodeVisitor):
valohai.prepare(parameters=get_parameters())
"""

def __init__(self):
def __init__(self) -> None:
self.assignments = {}
self.parameters = {}
self.inputs = {}
self.step = None
self.image = None

def visit_Assign(self, node):
def visit_Assign(self, node: ast.Assign) -> None:
try:
self.assignments[node.targets[0].id] = ast.literal_eval(node.value)
except ValueError:
# We don't care about assignments that can't be literal_eval():ed
pass

def visit_Call(self, node):
def visit_Call(self, node: ast.Call) -> None:
if is_module_function_call(node, "valohai", "prepare"):
self.process_valohai_prepare_call(node)

def process_valohai_prepare_call(self, node):
def process_valohai_prepare_call(self, node: ast.Call) -> None:
self.step = "default"
if hasattr(node, "keywords"):
for key in node.keywords:
Expand All @@ -70,7 +70,7 @@ def process_valohai_prepare_call(self, node):
elif key.arg == "image":
self.image = ast.literal_eval(key.value)

def process_default_inputs_arg(self, key):
def process_default_inputs_arg(self, key: ast.keyword) -> None:
if isinstance(key.value, ast.Name) and key.value.id in self.assignments:
self.inputs = {
key: value if isinstance(value, List) else [value]
Expand All @@ -84,7 +84,7 @@ def process_default_inputs_arg(self, key):
else:
raise NotImplementedError()

def process_default_parameters_arg(self, key):
def process_default_parameters_arg(self, key: ast.keyword) -> None:
if isinstance(key.value, ast.Name) and key.value.id in self.assignments:
self.parameters = self.assignments[key.value.id]
elif isinstance(key.value, ast.Dict):
Expand Down
42 changes: 24 additions & 18 deletions valohai/internals/vfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import shutil
import tempfile
from contextlib import ExitStack
from tarfile import TarFile, TarInfo
from tarfile import ExFileObject, TarFile, TarInfo
from typing import IO, Optional, Union
from zipfile import ZipFile, ZipInfo
from zipfile import ZipExtFile, ZipFile, ZipInfo


class File:
Expand All @@ -14,7 +14,7 @@ class File:
def open(self) -> io.BufferedReader:
raise NotImplementedError("...")

def read(self):
def read(self) -> bytes:
with self.open() as f:
return f.read()

Expand All @@ -35,12 +35,14 @@ class FileOnDisk(File):
path: str
dir_entry: Optional[os.DirEntry]

def __init__(self, name, path, dir_entry=None):
def __init__(
self, name: str, path: str, dir_entry: Optional[os.DirEntry] = None
) -> None:
self._name = name
self.path = path
self.dir_entry = dir_entry

def open(self):
def open(self) -> io.BufferedReader:
return open(self.path, "rb") # noqa: SIM115

@property
Expand All @@ -51,7 +53,7 @@ def name(self) -> str:
class FileInContainer(File):
_concrete_path: Optional[str] = None

def open_concrete(self, delete=True):
def open_concrete(self, delete: bool = True) -> IO[bytes]:
if self._concrete_path and os.path.isfile(self._concrete_path):
return open(self._concrete_path) # noqa: SIM115
tf = tempfile.NamedTemporaryFile(suffix=self.extension, delete=delete)
Expand All @@ -60,7 +62,7 @@ def open_concrete(self, delete=True):
self._concrete_path = tf.name
return tf

def extract(self, destination: Union[str, IO]):
def extract(self, destination: Union[str, IO]) -> None:
if isinstance(destination, str):
destination = open(destination, "wb") # noqa: SIM115
should_close = True
Expand All @@ -72,7 +74,7 @@ def extract(self, destination: Union[str, IO]):
if should_close:
destination.close()

def _do_extract(self, destination: IO):
def _do_extract(self, destination: IO) -> None:
# if a file has a better idea how to write itself into an IO, this is the place
with self.open() as f:
shutil.copyfileobj(f, destination)
Expand All @@ -82,12 +84,14 @@ class FileInZip(FileInContainer):
zipfile: ZipFile
zipinfo: ZipInfo

def __init__(self, parent_file, zipfile, zipinfo):
def __init__(
self, parent_file: FileOnDisk, zipfile: ZipFile, zipinfo: ZipInfo
) -> None:
self.parent_file = parent_file
self.zipfile = zipfile
self.zipinfo = zipinfo

def open(self):
def open(self) -> ZipExtFile:
return self.zipfile.open(self.zipinfo, "r")

@property
Expand All @@ -107,12 +111,14 @@ class FileInTar(FileInContainer):
tarfile: TarFile
tarinfo: TarInfo

def __init__(self, parent_file, tarfile, tarinfo):
def __init__(
self, parent_file: FileOnDisk, tarfile: TarFile, tarinfo: TarInfo
) -> None:
self.parent_file = parent_file
self.tarfile = tarfile
self.tarinfo = tarinfo

def open(self):
def open(self) -> ExFileObject:
return self.tarfile.extractfile(self.tarinfo)

@property
Expand Down Expand Up @@ -144,14 +150,14 @@ def find_files_in_tar(vr: "VFS", df: FileOnDisk) -> None:


class VFS:
def __init__(self):
def __init__(self) -> None:
self.files = []
self.exit_stack = ExitStack()

def __enter__(self):
def __enter__(self) -> "VFS":
return self

def __exit__(self, *exc_details):
def __exit__(self, *exc_details) -> None:
self.exit_stack.__exit__(*exc_details)


Expand All @@ -160,8 +166,8 @@ def add_disk_file(
name: str,
path: str,
dir_entry: Optional[os.DirEntry] = None,
process_archives=False,
):
process_archives: bool = False,
) -> None:
disk_file = FileOnDisk(name=name, path=path, dir_entry=dir_entry)
if process_archives:
extension = disk_file.extension.lower()
Expand All @@ -174,7 +180,7 @@ def add_disk_file(
vfs.files.append(disk_file)


def find_files(vfs: VFS, root: str, *, process_archives: bool):
def find_files(vfs: VFS, root: str, *, process_archives: bool) -> None:
dent: os.DirEntry

def _walk(path):
Expand Down
10 changes: 6 additions & 4 deletions valohai/internals/yaml.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
from typing import Any, Dict

from valohai_yaml.objs import Config, Parameter, Step
from valohai_yaml.objs.config import Config
from valohai_yaml.objs.input import Input, KeepDirectories
from valohai_yaml.objs.parameter import Parameter
from valohai_yaml.objs.step import Step

from valohai.consts import DEFAULT_DOCKER_IMAGE
from valohai.internals.parsing import parse
Expand All @@ -17,7 +19,7 @@ def generate_step(
step: str,
image: str,
parameters: ParameterDict,
inputs: InputDict
inputs: InputDict,
) -> Step:
config_step = Step(
name=step,
Expand Down Expand Up @@ -50,7 +52,7 @@ def generate_config(
step: str,
image: str,
parameters: ParameterDict,
inputs: InputDict
inputs: InputDict,
) -> Config:
step = generate_step(
relative_source_path=relative_source_path,
Expand Down Expand Up @@ -85,7 +87,7 @@ def get_source_relative_path(source_path: str, config_path: str) -> str:
return os.path.join(relative_source_dir, os.path.basename(source_path))


def parse_config_from_source(source_path: str, config_path: str):
def parse_config_from_source(source_path: str, config_path: str) -> Config:
with open(source_path) as source_file:
parsed = parse(source_file.read())
if not parsed.step:
Expand Down
Loading

0 comments on commit 9bf5b92

Please sign in to comment.