diff --git a/stubs/tensorflow/tensorflow/__init__.pyi b/stubs/tensorflow/tensorflow/__init__.pyi index 1021bff117a7..cb458d23196f 100644 --- a/stubs/tensorflow/tensorflow/__init__.pyi +++ b/stubs/tensorflow/tensorflow/__init__.pyi @@ -36,6 +36,7 @@ from tensorflow.core.protobuf import struct_pb2 # is necessary to avoid a crash in pytype. from tensorflow.dtypes import * from tensorflow.dtypes import DType as DType +from tensorflow.experimental.dtensor import Layout from tensorflow.keras import losses as losses # Most tf.math functions are exported as tf, but sadly not all are. @@ -406,7 +407,6 @@ class RaggedTensorSpec(TypeSpec[struct_pb2.TypeSpecProto]): @classmethod def from_value(cls, value: RaggedTensor) -> Self: ... -def __getattr__(name: str) -> Incomplete: ... def convert_to_tensor( value: TensorCompatible | IndexedSlices, dtype: DTypeLike | None = None, @@ -439,4 +439,23 @@ def cast(x: TensorCompatible, dtype: DTypeLike, name: str | None = None) -> Tens def cast(x: SparseTensor, dtype: DTypeLike, name: str | None = None) -> SparseTensor: ... @overload def cast(x: RaggedTensor, dtype: DTypeLike, name: str | None = None) -> RaggedTensor: ... +def zeros(shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None, layout: Layout | None = None) -> Tensor: ... +def ones(shape: ShapeLike, dtype: DTypeLike = ..., name: str | None = None, layout: Layout | None = None) -> Tensor: ... +@overload +def zeros_like( + input: TensorCompatible | IndexedSlices, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> Tensor: ... +@overload +def zeros_like( + input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> RaggedTensor: ... +@overload +def ones_like( + input: TensorCompatible, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> Tensor: ... +@overload +def ones_like( + input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None +) -> RaggedTensor: ... def reshape(tensor: TensorCompatible, shape: ShapeLike | Tensor, name: str | None = None) -> Tensor: ... +def __getattr__(name: str) -> Incomplete: ... diff --git a/stubs/tensorflow/tensorflow/experimental/dtensor.pyi b/stubs/tensorflow/tensorflow/experimental/dtensor.pyi index ffb22d344b55..178d1211f45a 100644 --- a/stubs/tensorflow/tensorflow/experimental/dtensor.pyi +++ b/stubs/tensorflow/tensorflow/experimental/dtensor.pyi @@ -2,6 +2,8 @@ from _typeshed import Incomplete from tensorflow._aliases import IntArray, IntDataSequence +Layout = Incomplete + class Mesh: def __init__( self,