Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add unhashed_package_root and PartialImport #157

Merged
merged 7 commits into from
Sep 11, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 90 additions & 5 deletions common/setups/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@
"""

from __future__ import annotations
from typing import Any, Union, Optional, List
from typing import Any, Dict, Union, Optional, List
from types import FunctionType
import string
import sys
import textwrap

from sisyphus import tk
from sisyphus.hash import sis_hash_helper, short_hash
from sisyphus.delayed_ops import DelayedBase

from i6_core.util import uopen
from i6_core.util import uopen, instanciate_delayed


class SerializerObject(DelayedBase):
Expand Down Expand Up @@ -72,16 +73,24 @@ class Import(SerializerObject):
def __init__(
self,
code_object_path: Union[str, FunctionType, Any],
import_as: Optional[str] = None,
*,
unhashed_package_root: Optional[str] = None,
JackTemaki marked this conversation as resolved.
Show resolved Hide resolved
import_as: Optional[str] = None,
use_for_hash: bool = True,
ignore_import_as_for_hash: bool = False,
):
"""
:param code_object_path: e.g. `i6_experiments.users.username.my_rc_files.SomeNiceASRModel`.
:param code_object_path: e.g.`i6_experiments.users.username.some_experiment.pytorch_networks.SomeNiceASRModel`.
This can be the object itself, e.g. a function or a class. Then it will use __qualname__ and __module__.
:param unhashed_package_root: The root path to a package, from where relatives paths will be hashed.
Recommended is to use the root folder of an experiment module. E.g.:
`i6_experiments.users.username.some_experiment`
which could be retrieved via `__package__` from a module in the root of the `some_experiment` folder.
In case one wants to avoid hash conflicts this might cause, passing an `ExplicitHash` object to the
same collection as the import is possible.
:param import_as: if given, the code object will be imported as this name
:param use_for_hash:
:param use_for_hash: if False, this import is not hashed when passed to a Collection/Serializer
:param ignore_import_as_for_hash: do not hash `import_as` if set
"""
super().__init__()
if not isinstance(code_object_path, str):
Expand All @@ -96,6 +105,14 @@ def __init__(
self.object_name = self.code_object.split(".")[-1]
self.module = ".".join(self.code_object.split(".")[:-1])
self.package = ".".join(self.code_object.split(".")[:-2])
JackTemaki marked this conversation as resolved.
Show resolved Hide resolved

if unhashed_package_root:
if not self.code_object.startswith(unhashed_package_root):
raise ValueError(
f"unhashed_package_root: {unhashed_package_root} is not a prefix of {self.code_object}"
)
self.code_object = self.code_object[len(unhashed_package_root) :]

self.import_as = import_as
self.use_for_hash = use_for_hash
self.ignore_import_as_for_hash = ignore_import_as_for_hash
Expand All @@ -112,6 +129,74 @@ def _sis_hash(self):
return sis_hash_helper(self.code_object)


class PartialImport(Import):
"""
Like Import, but for partial callables where certain parameters are given fixed and are hashed.
JackTemaki marked this conversation as resolved.
Show resolved Hide resolved
"""

TEMPLATE = textwrap.dedent(
"""\
${OBJECT_NAME} = __import__("functools").partial(
__import__("${IMPORT_PATH}", fromlist=["${IMPORT_NAME}"]).${IMPORT_NAME},
**${KWARGS}
)
"""
JackTemaki marked this conversation as resolved.
Show resolved Hide resolved
)

def __init__(
self,
*,
code_object_path: Union[str, FunctionType, Any],
unhashed_package_root: str,
hashed_arguments: Dict[str, Any],
unhashed_arguments: Dict[str, Any],
import_as: Optional[str] = None,
use_for_hash: bool = True,
ignore_import_as_for_hash: bool = False,
):
"""
:param code_object_path: e.g.`i6_experiments.users.username.some_experiment.pytorch_networks.SomeNiceASRModel`.
This can be the object itself, e.g. a function or a class. Then it will use __qualname__ and __module__.
:param unhashed_package_root: The root path to a package, from where relatives paths will be hashed.
JackTemaki marked this conversation as resolved.
Show resolved Hide resolved
Recommended is to use the root folder of an experiment module. E.g.:
`i6_experiments.users.username.some_experiment`
which could be retrieved via `__package__` from a module in the root of the `some_experiment` folder.
In case one wants to avoid hash conflicts this might cause, passing an `ExplicitHash` object to the
same collection as the import is possible.
:param hashed_arguments: argument dictionary for addition partial arguments to set to the callable.
Will be serialized as dict into the config, so make sure to use only serializable/parseable content
:param unhashed_arguments: same as above, but does not influence the hash
:param import_as: if given, the code object will be imported as this name
:param use_for_hash: if False, this module is not hashed when passed to a Collection/Serializer
:param ignore_import_as_for_hash: do not hash `import_as` if set
"""

super().__init__(
code_object_path=code_object_path,
unhashed_package_root=unhashed_package_root,
import_as=import_as,
use_for_hash=use_for_hash,
ignore_import_as_for_hash=ignore_import_as_for_hash,
)
self.hashed_arguments = hashed_arguments
self.unhashed_arguments = unhashed_arguments

def get(self) -> str:
arguments = {**self.unhashed_arguments, **self.hashed_arguments}
return string.Template(self.TEMPLATE).substitute(
{
"KWARGS": str(instanciate_delayed(arguments)),
"IMPORT_PATH": self.module,
"IMPORT_NAME": self.object_name,
"OBJECT_NAME": self.import_as if self.import_as is not None else self.object_name,
}
)

def _sis_hash(self):
super_hash = super()._sis_hash()
return sis_hash_helper({"import": super_hash, "hashed_arguments": self.hashed_arguments})


class ExternalImport(SerializerObject):
"""
Import from e.g. a git repository. For imports within the recipes use "Import".
Expand Down