Skip to content

Commit

Permalink
Created FinalizeTPUEmbeddingV2 to output embedding_partitions and…
Browse files Browse the repository at this point in the history
… `hbm_buffers_config`, and created V2 Ops for BC XLA Ops which accept `embedding_partitions`, `hbm_buffers_config` and the serialization of `TpuTopologyArgsProto`.

PiperOrigin-RevId: 617397729
  • Loading branch information
Dateng Lin authored and copybara-github committed Mar 20, 2024
1 parent eac67fd commit 5e5c7b9
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions xla/stream_executor/tpu/tpu_ops_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,9 @@ typedef struct TpuEmbeddingEngine_RecvActivationsComputation_Params {
void* priv;

TpuSerializedProto tpu_embedding_config;
TpuSerializedProto embedding_partitions;
TpuSerializedProto hbm_buffers_config;
TpuSerializedProto tpu_topology;
XLA_Shape* deduplication_data_shape;
TpuSerializedProto* op_sharding;

Expand All @@ -652,6 +655,9 @@ typedef struct
void* priv;

TpuSerializedProto tpu_embedding_config;
TpuSerializedProto embedding_partitions;
TpuSerializedProto hbm_buffers_config;
TpuSerializedProto tpu_topology;
TpuSerializedProto* op_sharding;
// out
TpuSerializedProto* xla_computation;
Expand All @@ -669,6 +675,9 @@ typedef struct TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params {

int32_t num_inputs;
TpuSerializedProto tpu_embedding_config;
TpuSerializedProto embedding_partitions;
TpuSerializedProto hbm_buffers_config;
TpuSerializedProto tpu_topology;
XLA_Shape* learning_rate_tuple_shape;
XLA_Shape* deduplication_data_shape;
XLA_Shape* gradient_tuple_shape;
Expand All @@ -686,6 +695,9 @@ typedef struct TpuEmbeddingEngine_DedupDataSizeComputation_Params {
void* priv;

TpuSerializedProto tpu_embedding_config;
TpuSerializedProto embedding_partitions;
TpuSerializedProto hbm_buffers_config;
TpuSerializedProto tpu_topology;
// out
int32_t* num_elements;
TF_Status* status;
Expand All @@ -699,6 +711,9 @@ typedef struct TpuEmbeddingEngine_DedupDataTupleMaskComputation_Params {
void* priv;

TpuSerializedProto tpu_embedding_config;
TpuSerializedProto embedding_partitions;
TpuSerializedProto hbm_buffers_config;
TpuSerializedProto tpu_topology;
// out
TpuSerializedProto* xla_computation;
TF_Status* status;
Expand Down

0 comments on commit 5e5c7b9

Please sign in to comment.