From 4bb1902676dc41a3d01e2b468a9d3e2cef17c69f Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Tue, 6 Feb 2024 19:02:11 +0900 Subject: [PATCH 1/2] feat: add tf.ones, tf.zeros, tf.zeros_like and tf.ones_like --- stubs/tensorflow/tensorflow/__init__.pyi | 19 +++++++++++++++++++ .../tensorflow/experimental/dtensor.pyi | 3 +++ 2 files changed, 22 insertions(+) create mode 100644 stubs/tensorflow/tensorflow/experimental/dtensor.pyi diff --git a/stubs/tensorflow/tensorflow/__init__.pyi b/stubs/tensorflow/tensorflow/__init__.pyi index 1021bff117a7..e7807bce51e4 100644 --- a/stubs/tensorflow/tensorflow/__init__.pyi +++ b/stubs/tensorflow/tensorflow/__init__.pyi @@ -36,6 +36,7 @@ from tensorflow.core.protobuf import struct_pb2 # is necessary to avoid a crash in pytype. from tensorflow.dtypes import * from tensorflow.dtypes import DType as DType +from tensorflow.experimental.dtensor import Layout from tensorflow.keras import losses as losses # Most tf.math functions are exported as tf, but sadly not all are. @@ -439,4 +440,22 @@ def cast(x: TensorCompatible, dtype: DTypeLike, name: str | None = None) -> Tens def cast(x: SparseTensor, dtype: DTypeLike, name: str | None = None) -> SparseTensor: ... @overload def cast(x: RaggedTensor, dtype: DTypeLike, name: str | None = None) -> RaggedTensor: ... +def zeros(shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None, layout: Layout | None = None) -> Tensor: ... +def ones(shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None, layout: Layout | None = None) -> Tensor: ... +@overload +def zeros_like( + input: TensorCompatible | IndexedSlices, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> Tensor: ... +@overload +def zeros_like( + input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> RaggedTensor: ... +@overload +def ones_like( + input: TensorCompatible, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> Tensor: ... +@overload +def ones_like( + input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> RaggedTensor: ... def reshape(tensor: TensorCompatible, shape: ShapeLike | Tensor, name: str | None = None) -> Tensor: ... diff --git a/stubs/tensorflow/tensorflow/experimental/dtensor.pyi b/stubs/tensorflow/tensorflow/experimental/dtensor.pyi new file mode 100644 index 000000000000..d5d3d3716ad0 --- /dev/null +++ b/stubs/tensorflow/tensorflow/experimental/dtensor.pyi @@ -0,0 +1,3 @@ +from _typeshed import Incomplete + +Layout = Incomplete From e2bf4c4f2fe3c3a7405c4093730acfd1da3c007e Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Wed, 7 Feb 2024 22:01:37 +0900 Subject: [PATCH 2/2] add __getattr__ --- stubs/tensorflow/tensorflow/__init__.pyi | 2 +- stubs/tensorflow/tensorflow/experimental/dtensor.pyi | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/stubs/tensorflow/tensorflow/__init__.pyi b/stubs/tensorflow/tensorflow/__init__.pyi index e7807bce51e4..cb458d23196f 100644 --- a/stubs/tensorflow/tensorflow/__init__.pyi +++ b/stubs/tensorflow/tensorflow/__init__.pyi @@ -407,7 +407,6 @@ class RaggedTensorSpec(TypeSpec[struct_pb2.TypeSpecProto]): @classmethod def from_value(cls, value: RaggedTensor) -> Self: ... -def __getattr__(name: str) -> Incomplete: ... def convert_to_tensor( value: TensorCompatible | IndexedSlices, dtype: DTypeLike | None = None, @@ -459,3 +458,4 @@ def ones_like( input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None ) -> RaggedTensor: ... def reshape(tensor: TensorCompatible, shape: ShapeLike | Tensor, name: str | None = None) -> Tensor: ... +def __getattr__(name: str) -> Incomplete: ... diff --git a/stubs/tensorflow/tensorflow/experimental/dtensor.pyi b/stubs/tensorflow/tensorflow/experimental/dtensor.pyi index d5d3d3716ad0..92380280c699 100644 --- a/stubs/tensorflow/tensorflow/experimental/dtensor.pyi +++ b/stubs/tensorflow/tensorflow/experimental/dtensor.pyi @@ -1,3 +1,5 @@ from _typeshed import Incomplete Layout = Incomplete + +def __getattr__(name: str) -> Incomplete: ...