tensorflow add tensorflow.saved_model (#11439)
hoel-bagard authored Mar 13, 2024
1 parent 8a9dcb1 commit 56bb53a
Showing 7 changed files with 206 additions and 1 deletion.
6 changes: 6 additions & 0 deletions stubs/tensorflow/@tests/stubtest_allowlist.txt
@@ -105,5 +105,11 @@ tensorflow.train.Int64List.*
tensorflow.train.ClusterDef.*
tensorflow.train.ServerDef.*

# The python module cannot be accessed directly, so to stubtest it appears not to be present at runtime.
# However it can be accessed by doing:
# from tensorflow import python
# python.X
tensorflow.python.*

# sigmoid_cross_entropy_with_logits has default values (None), however those values are not valid.
tensorflow.nn.sigmoid_cross_entropy_with_logits
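
For context, the new allowlist entry covers the indirect access pattern the comment above describes; a minimal sketch, restating only what the comment says:

from tensorflow import python  # the indirect access mentioned in the allowlist comment

# Attributes are then reachable as python.X, e.g. the trackable internals stubbed below.
print(type(python))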
3 changes: 2 additions & 1 deletion stubs/tensorflow/tensorflow/_aliases.pyi
@@ -27,7 +27,8 @@ class KerasSerializable2(Protocol):

KerasSerializable: TypeAlias = KerasSerializable1 | KerasSerializable2

Integer: TypeAlias = tf.Tensor | int | IntArray | np.number[Any] # Here tf.Tensor and IntArray are assumed to be 0D.
TensorValue: TypeAlias = tf.Tensor # Alias for a 0D Tensor
Integer: TypeAlias = TensorValue | int | IntArray | np.number[Any] # Here IntArray is assumed to be 0D.
Slice: TypeAlias = int | slice | None
FloatDataSequence: TypeAlias = Sequence[float] | Sequence[FloatDataSequence]
IntDataSequence: TypeAlias = Sequence[int] | Sequence[IntDataSequence]
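
To make the intent of the new aliases concrete, a small illustrative sketch; tensorflow._aliases is a stub-only helper module, so the aliases are restated here rather than imported, and the IntArray definition is an assumption rather than copied from the stubs:

from typing import Any
from typing_extensions import TypeAlias

import numpy as np
import numpy.typing as npt
import tensorflow as tf

TensorValue: TypeAlias = tf.Tensor  # a 0-D tensor, as in the stub above
IntArray: TypeAlias = npt.NDArray[np.integer[Any]]  # assumption: mirrors the stubs' IntArray
Integer: TypeAlias = TensorValue | int | IntArray | np.number[Any]

def repeat(x: tf.Tensor, times: Integer) -> tf.Tensor:
    # Accepts a 0-D tensor, a Python int, a 0-D integer array, or a numpy scalar.
    return tf.repeat(x, times)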
11 changes: 11 additions & 0 deletions stubs/tensorflow/tensorflow/python/trackable/resource.pyi
@@ -0,0 +1,11 @@
from _typeshed import Incomplete

from tensorflow.python.trackable.base import Trackable

class _ResourceMetaclass(type): ...

# Internal type that is commonly used as a base class;
# it is needed for the public signatures of some APIs.
class CapturableResource(Trackable, metaclass=_ResourceMetaclass): ...

def __getattr__(name: str) -> Incomplete: ...
3 changes: 3 additions & 0 deletions stubs/tensorflow/tensorflow/python/training/tracking/autotrackable.pyi
@@ -0,0 +1,3 @@
from _typeshed import Incomplete

AutoTrackable = Incomplete
115 changes: 115 additions & 0 deletions stubs/tensorflow/tensorflow/saved_model/__init__.pyi
@@ -0,0 +1,115 @@
from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any, Generic, Literal, TypeVar
from typing_extensions import ParamSpec, TypeAlias

import tensorflow as tf
from tensorflow.python.training.tracking.autotrackable import AutoTrackable
from tensorflow.saved_model.experimental import VariablePolicy
from tensorflow.types.experimental import ConcreteFunction, GenericFunction

_P = ParamSpec("_P")
_R = TypeVar("_R", covariant=True)

class Asset:
    @property
    def asset_path(self) -> tf.Tensor: ...
    def __init__(self, path: str | Path | tf.Tensor) -> None: ...

class LoadOptions:
    allow_partial_checkpoint: bool
    experimental_io_device: str | None
    experimental_skip_checkpoint: bool
    experimental_variable_policy: VariablePolicy | None
    experimental_load_function_aliases: bool

    def __init__(
        self,
        allow_partial_checkpoint: bool = False,
        experimental_io_device: str | None = None,
        experimental_skip_checkpoint: bool = False,
        experimental_variable_policy: (
            VariablePolicy | Literal["expand_distributed_variables", "save_variable_devices"] | None
        ) = None,
        experimental_load_function_aliases: bool = False,
    ) -> None: ...

class SaveOptions:
    __slots__ = (
        "namespace_whitelist",
        "save_debug_info",
        "function_aliases",
        "experimental_io_device",
        "experimental_variable_policy",
        "experimental_custom_gradients",
        "experimental_image_format",
        "experimental_skip_saver",
    )
    namespace_whitelist: list[str]
    save_debug_info: bool
    function_aliases: dict[str, tf.types.experimental.GenericFunction[..., object]]
    experimental_io_device: str
    experimental_variable_policy: VariablePolicy
    experimental_custom_gradients: bool
    experimental_image_format: bool
    experimental_skip_saver: bool
    def __init__(
        self,
        namespace_whitelist: list[str] | None = None,
        save_debug_info: bool = False,
        function_aliases: Mapping[str, tf.types.experimental.GenericFunction[..., object]] | None = None,
        experimental_io_device: str | None = None,
        experimental_variable_policy: str | VariablePolicy | None = None,
        experimental_custom_gradients: bool = True,
        experimental_image_format: bool = False,
        experimental_skip_saver: bool = False,
    ) -> None: ...

def contains_saved_model(export_dir: str | Path) -> bool: ...

class _LoadedAttributes(Generic[_P, _R]):
    signatures: Mapping[str, ConcreteFunction[_P, _R]]

class _LoadedModel(AutoTrackable, _LoadedAttributes[_P, _R]):
    variables: list[tf.Variable]
    trainable_variables: list[tf.Variable]
    # TF1 model artifact specific
    graph: tf.Graph

def load(
    export_dir: str, tags: str | Sequence[str] | None = None, options: LoadOptions | None = None
) -> _LoadedModel[..., Any]: ...

_TF_Function: TypeAlias = ConcreteFunction[..., object] | GenericFunction[..., object]

def save(
    obj: tf.Module,
    export_dir: str,
    signatures: _TF_Function | Mapping[str, _TF_Function] | None = None,
    options: SaveOptions | None = None,
) -> None: ...

ASSETS_DIRECTORY: str = "assets"
ASSETS_KEY: str = "saved_model_assets"
CLASSIFY_INPUTS: str = "inputs"
CLASSIFY_METHOD_NAME: str = "tensorflow/serving/classify"
CLASSIFY_OUTPUT_CLASSES: str = "classes"
CLASSIFY_OUTPUT_SCORES: str = "scores"
DEBUG_DIRECTORY: str = "debug"
DEBUG_INFO_FILENAME_PB: str = "saved_model_debug_info.pb"
DEFAULT_SERVING_SIGNATURE_DEF_KEY: str = "serving_default"
GPU: str = "gpu"
PREDICT_INPUTS: str = "inputs"
PREDICT_METHOD_NAME: str = "tensorflow/serving/predict"
PREDICT_OUTPUTS: str = "outputs"
REGRESS_INPUTS: str = "inputs"
REGRESS_METHOD_NAME: str = "tensorflow/serving/regress"
REGRESS_OUTPUTS: str = "outputs"
SAVED_MODEL_FILENAME_PB: str = "saved_model.pb"
SAVED_MODEL_FILENAME_PBTXT: str = "saved_model.pbtxt"
SAVED_MODEL_SCHEMA_VERSION: int = 1
SERVING: str = "serve"
TPU: str = "tpu"
TRAINING: str = "train"
VARIABLES_DIRECTORY: str = "variables"
VARIABLES_FILENAME: str = "variables"
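
To show the public API these stubs annotate in action, a short usage sketch; the Adder module and the /tmp/adder path are illustrative only:

import tensorflow as tf

class Adder(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def add(self, x):
        return x + 1.0

model = Adder()
# save() takes a tf.Module plus optional signatures and SaveOptions, as stubbed above.
tf.saved_model.save(model, "/tmp/adder", options=tf.saved_model.SaveOptions())

# load() returns an AutoTrackable-like object; its `signatures` mapping holds
# ConcreteFunctions keyed by signature name (e.g. "serving_default").
restored = tf.saved_model.load("/tmp/adder", options=tf.saved_model.LoadOptions())
print(restored.add(tf.constant(2.0)))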
40 changes: 40 additions & 0 deletions stubs/tensorflow/tensorflow/saved_model/experimental.pyi
@@ -0,0 +1,40 @@
from _typeshed import Incomplete
from enum import Enum
from typing_extensions import Self

import tensorflow as tf
from tensorflow._aliases import Integer, TensorValue
from tensorflow.python.trackable.resource import CapturableResource

class Fingerprint:
    saved_model_checksum: TensorValue | None
    graph_def_program_hash: TensorValue | None = None
    signature_def_hash: TensorValue | None = None
    saved_object_graph_hash: TensorValue | None = None
    checkpoint_hash: TensorValue | None = None
    version: TensorValue | None = None
    # In practice it seems like any type is accepted, but that might cause issues later on.
    def __init__(
        self,
        saved_model_checksum: Integer | None = None,
        graph_def_program_hash: Integer | None = None,
        signature_def_hash: Integer | None = None,
        saved_object_graph_hash: Integer | None = None,
        checkpoint_hash: Integer | None = None,
        version: Integer | None = None,
    ) -> None: ...
    @classmethod
    def from_proto(cls, proto: Incomplete) -> Self: ...
    def singleprint(self) -> str: ...

class TrackableResource(CapturableResource):
    @property
    def resource_handle(self) -> tf.Tensor: ...
    def __init__(self, device: str = "") -> None: ...

class VariablePolicy(Enum):
    EXPAND_DISTRIBUTED_VARIABLES = "expand_distributed_variables"
    NONE = None  # noqa: Y026
    SAVE_VARIABLE_DEVICES = "save_variable_devices"

def read_fingerprint(export_dir: str) -> Fingerprint: ...
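
A brief sketch of how the fingerprinting API stubbed above is typically called; the export directory is hypothetical, and read_fingerprint only succeeds when the SavedModel was written by a TensorFlow version that records a fingerprint file:

import tensorflow as tf

# Reads the fingerprint recorded alongside saved_model.pb.
fp = tf.saved_model.experimental.read_fingerprint("/tmp/adder")
print(fp.saved_model_checksum)  # per the stub above, a 0-D tensor-like value (or None)
print(fp.singleprint())         # stable string derived from the hash fields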
29 changes: 29 additions & 0 deletions stubs/tensorflow/tensorflow/types/experimental.pyi
@@ -0,0 +1,29 @@
import abc
from _typeshed import Incomplete
from typing import Any, Generic, TypeVar, overload
from typing_extensions import ParamSpec

import tensorflow as tf
from tensorflow._aliases import ContainerGeneric

_P = ParamSpec("_P")
_R = TypeVar("_R", covariant=True)

class Callable(Generic[_P, _R], metaclass=abc.ABCMeta):
    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...

class ConcreteFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...

class GenericFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
    @overload
    @abc.abstractmethod
    def get_concrete_function(self, *args: _P.args, **kwargs: _P.kwargs) -> ConcreteFunction[_P, _R]: ...
    @overload
    @abc.abstractmethod
    def get_concrete_function(
        self, *args: ContainerGeneric[tf.TypeSpec[Any]], **kwargs: ContainerGeneric[tf.TypeSpec[Any]]
    ) -> ConcreteFunction[_P, _R]: ...
    def experimental_get_compiler_ir(self, *args: Incomplete, **kwargs: Incomplete) -> Incomplete: ...

def __getattr__(name: str) -> Incomplete: ...
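
For reference, the Callable / ConcreteFunction / GenericFunction hierarchy above describes what tf.function produces; a small sketch with a made-up square function:

import tensorflow as tf

@tf.function
def square(x):
    return x * x

# A tf.function behaves like GenericFunction above: calling it dispatches/traces,
# while get_concrete_function() pins down a ConcreteFunction for one input signature.
concrete = square.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32))
print(concrete(tf.constant(3.0)))  # 9.0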
