From 5e5c7b97d21e3c182622ddf11202038d74d4972a Mon Sep 17 00:00:00 2001 From: Dateng Lin Date: Tue, 19 Mar 2024 22:21:22 -0700 Subject: [PATCH] Created `FinalizeTPUEmbeddingV2` to output `embedding_partitions` and `hbm_buffers_config`, and created V2 Ops for BC XLA Ops which accept `embedding_partitions`, `hbm_buffers_config` and the serialization of `TpuTopologyArgsProto`. PiperOrigin-RevId: 617397729 --- xla/stream_executor/tpu/tpu_ops_c_api.h | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/xla/stream_executor/tpu/tpu_ops_c_api.h b/xla/stream_executor/tpu/tpu_ops_c_api.h index 0db1b51f91f38..80365ebb046a2 100644 --- a/xla/stream_executor/tpu/tpu_ops_c_api.h +++ b/xla/stream_executor/tpu/tpu_ops_c_api.h @@ -635,6 +635,9 @@ typedef struct TpuEmbeddingEngine_RecvActivationsComputation_Params { void* priv; TpuSerializedProto tpu_embedding_config; + TpuSerializedProto embedding_partitions; + TpuSerializedProto hbm_buffers_config; + TpuSerializedProto tpu_topology; XLA_Shape* deduplication_data_shape; TpuSerializedProto* op_sharding; @@ -652,6 +655,9 @@ typedef struct void* priv; TpuSerializedProto tpu_embedding_config; + TpuSerializedProto embedding_partitions; + TpuSerializedProto hbm_buffers_config; + TpuSerializedProto tpu_topology; TpuSerializedProto* op_sharding; // out TpuSerializedProto* xla_computation; @@ -669,6 +675,9 @@ typedef struct TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params { int32_t num_inputs; TpuSerializedProto tpu_embedding_config; + TpuSerializedProto embedding_partitions; + TpuSerializedProto hbm_buffers_config; + TpuSerializedProto tpu_topology; XLA_Shape* learning_rate_tuple_shape; XLA_Shape* deduplication_data_shape; XLA_Shape* gradient_tuple_shape; @@ -686,6 +695,9 @@ typedef struct TpuEmbeddingEngine_DedupDataSizeComputation_Params { void* priv; TpuSerializedProto tpu_embedding_config; + TpuSerializedProto embedding_partitions; + TpuSerializedProto hbm_buffers_config; + TpuSerializedProto tpu_topology; // out int32_t* num_elements; TF_Status* status; @@ -699,6 +711,9 @@ typedef struct TpuEmbeddingEngine_DedupDataTupleMaskComputation_Params { void* priv; TpuSerializedProto tpu_embedding_config; + TpuSerializedProto embedding_partitions; + TpuSerializedProto hbm_buffers_config; + TpuSerializedProto tpu_topology; // out TpuSerializedProto* xla_computation; TF_Status* status;