-
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Based on: - https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/saved_model/__init__.pyi - https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/types/experimental.pyi
- Loading branch information
1 parent
8a9dcb1
commit 56bb53a
Showing
7 changed files
with
206 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from _typeshed import Incomplete | ||
|
||
from tensorflow.python.trackable.base import Trackable | ||
|
||
class _ResourceMetaclass(type): ... | ||
|
||
# Internal type that is commonly used as a base class | ||
# it is needed for the public signatures of some APIs. | ||
class CapturableResource(Trackable, metaclass=_ResourceMetaclass): ... | ||
|
||
def __getattr__(name: str) -> Incomplete: ... |
3 changes: 3 additions & 0 deletions
3
stubs/tensorflow/tensorflow/python/training/tracking/autotrackable.pyi
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from _typeshed import Incomplete | ||
|
||
AutoTrackable = Incomplete |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from collections.abc import Mapping, Sequence | ||
from pathlib import Path | ||
from typing import Any, Generic, Literal, TypeVar | ||
from typing_extensions import ParamSpec, TypeAlias | ||
|
||
import tensorflow as tf | ||
from tensorflow.python.training.tracking.autotrackable import AutoTrackable | ||
from tensorflow.saved_model.experimental import VariablePolicy | ||
from tensorflow.types.experimental import ConcreteFunction, GenericFunction | ||
|
||
_P = ParamSpec("_P") | ||
_R = TypeVar("_R", covariant=True) | ||
|
||
class Asset: | ||
@property | ||
def asset_path(self) -> tf.Tensor: ... | ||
def __init__(self, path: str | Path | tf.Tensor) -> None: ... | ||
|
||
class LoadOptions: | ||
allow_partial_checkpoint: bool | ||
experimental_io_device: str | None | ||
experimental_skip_checkpoint: bool | ||
experimental_variable_policy: VariablePolicy | None | ||
experimental_load_function_aliases: bool | ||
|
||
def __init__( | ||
self, | ||
allow_partial_checkpoint: bool = False, | ||
experimental_io_device: str | None = None, | ||
experimental_skip_checkpoint: bool = False, | ||
experimental_variable_policy: ( | ||
VariablePolicy | Literal["expand_distributed_variables", "save_variable_devices"] | None | ||
) = None, | ||
experimental_load_function_aliases: bool = False, | ||
) -> None: ... | ||
|
||
class SaveOptions: | ||
__slots__ = ( | ||
"namespace_whitelist", | ||
"save_debug_info", | ||
"function_aliases", | ||
"experimental_io_device", | ||
"experimental_variable_policy", | ||
"experimental_custom_gradients", | ||
"experimental_image_format", | ||
"experimental_skip_saver", | ||
) | ||
namespace_whitelist: list[str] | ||
save_debug_info: bool | ||
function_aliases: dict[str, tf.types.experimental.GenericFunction[..., object]] | ||
experimental_io_device: str | ||
experimental_variable_policy: VariablePolicy | ||
experimental_custom_gradients: bool | ||
experimental_image_format: bool | ||
experimental_skip_saver: bool | ||
def __init__( | ||
self, | ||
namespace_whitelist: list[str] | None = None, | ||
save_debug_info: bool = False, | ||
function_aliases: Mapping[str, tf.types.experimental.GenericFunction[..., object]] | None = None, | ||
experimental_io_device: str | None = None, | ||
experimental_variable_policy: str | VariablePolicy | None = None, | ||
experimental_custom_gradients: bool = True, | ||
experimental_image_format: bool = False, | ||
experimental_skip_saver: bool = False, | ||
) -> None: ... | ||
|
||
def contains_saved_model(export_dir: str | Path) -> bool: ... | ||
|
||
class _LoadedAttributes(Generic[_P, _R]): | ||
signatures: Mapping[str, ConcreteFunction[_P, _R]] | ||
|
||
class _LoadedModel(AutoTrackable, _LoadedAttributes[_P, _R]): | ||
variables: list[tf.Variable] | ||
trainable_variables: list[tf.Variable] | ||
# TF1 model artifact specific | ||
graph: tf.Graph | ||
|
||
def load( | ||
export_dir: str, tags: str | Sequence[str] | None = None, options: LoadOptions | None = None | ||
) -> _LoadedModel[..., Any]: ... | ||
|
||
_TF_Function: TypeAlias = ConcreteFunction[..., object] | GenericFunction[..., object] | ||
|
||
def save( | ||
obj: tf.Module, | ||
export_dir: str, | ||
signatures: _TF_Function | Mapping[str, _TF_Function] | None = None, | ||
options: SaveOptions | None = None, | ||
) -> None: ... | ||
|
||
ASSETS_DIRECTORY: str = "assets" | ||
ASSETS_KEY: str = "saved_model_assets" | ||
CLASSIFY_INPUTS: str = "inputs" | ||
CLASSIFY_METHOD_NAME: str = "tensorflow/serving/classify" | ||
CLASSIFY_OUTPUT_CLASSES: str = "classes" | ||
CLASSIFY_OUTPUT_SCORES: str = "scores" | ||
DEBUG_DIRECTORY: str = "debug" | ||
DEBUG_INFO_FILENAME_PB: str = "saved_model_debug_info.pb" | ||
DEFAULT_SERVING_SIGNATURE_DEF_KEY: str = "serving_default" | ||
GPU: str = "gpu" | ||
PREDICT_INPUTS: str = "inputs" | ||
PREDICT_METHOD_NAME: str = "tensorflow/serving/predict" | ||
PREDICT_OUTPUTS: str = "outputs" | ||
REGRESS_INPUTS: str = "inputs" | ||
REGRESS_METHOD_NAME: str = "tensorflow/serving/regress" | ||
REGRESS_OUTPUTS: str = "outputs" | ||
SAVED_MODEL_FILENAME_PB: str = "saved_model.pb" | ||
SAVED_MODEL_FILENAME_PBTXT: str = "saved_model.pbtxt" | ||
SAVED_MODEL_SCHEMA_VERSION: int = 1 | ||
SERVING: str = "serve" | ||
TPU: str = "tpu" | ||
TRAINING: str = "train" | ||
VARIABLES_DIRECTORY: str = "variables" | ||
VARIABLES_FILENAME: str = "variables" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from _typeshed import Incomplete | ||
from enum import Enum | ||
from typing_extensions import Self | ||
|
||
import tensorflow as tf | ||
from tensorflow._aliases import Integer, TensorValue | ||
from tensorflow.python.trackable.resource import CapturableResource | ||
|
||
class Fingerprint: | ||
saved_model_checksum: TensorValue | None | ||
graph_def_program_hash: TensorValue | None = None | ||
signature_def_hash: TensorValue | None = None | ||
saved_object_graph_hash: TensorValue | None = None | ||
checkpoint_hash: TensorValue | None = None | ||
version: TensorValue | None = None | ||
# In practice it seems like any type is accepted, but that might cause issues later on. | ||
def __init__( | ||
self, | ||
saved_model_checksum: Integer | None = None, | ||
graph_def_program_hash: Integer | None = None, | ||
signature_def_hash: Integer | None = None, | ||
saved_object_graph_hash: Integer | None = None, | ||
checkpoint_hash: Integer | None = None, | ||
version: Integer | None = None, | ||
) -> None: ... | ||
@classmethod | ||
def from_proto(cls, proto: Incomplete) -> Self: ... | ||
def singleprint(self) -> str: ... | ||
|
||
class TrackableResource(CapturableResource): | ||
@property | ||
def resource_handle(self) -> tf.Tensor: ... | ||
def __init__(self, device: str = "") -> None: ... | ||
|
||
class VariablePolicy(Enum): | ||
EXPAND_DISTRIBUTED_VARIABLES = "expand_distributed_variables" | ||
NONE = None # noqa: Y026 | ||
SAVE_VARIABLE_DEVICES = "save_variable_devices" | ||
|
||
def read_fingerprint(export_dir: str) -> Fingerprint: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import abc | ||
from _typeshed import Incomplete | ||
from typing import Any, Generic, TypeVar, overload | ||
from typing_extensions import ParamSpec | ||
|
||
import tensorflow as tf | ||
from tensorflow._aliases import ContainerGeneric | ||
|
||
_P = ParamSpec("_P") | ||
_R = TypeVar("_R", covariant=True) | ||
|
||
class Callable(Generic[_P, _R], metaclass=abc.ABCMeta): | ||
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ... | ||
|
||
class ConcreteFunction(Callable[_P, _R], metaclass=abc.ABCMeta): | ||
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ... | ||
|
||
class GenericFunction(Callable[_P, _R], metaclass=abc.ABCMeta): | ||
@overload | ||
@abc.abstractmethod | ||
def get_concrete_function(self, *args: _P.args, **kwargs: _P.kwargs) -> ConcreteFunction[_P, _R]: ... | ||
@overload | ||
@abc.abstractmethod | ||
def get_concrete_function( | ||
self, *args: ContainerGeneric[tf.TypeSpec[Any]], **kwargs: ContainerGeneric[tf.TypeSpec[Any]] | ||
) -> ConcreteFunction[_P, _R]: ... | ||
def experimental_get_compiler_ir(self, *args: Incomplete, **kwargs: Incomplete) -> Incomplete: ... | ||
|
||
def __getattr__(name: str) -> Incomplete: ... |