diff --git a/stablehlo/dialect/ChloOps.cpp b/stablehlo/dialect/ChloOps.cpp index f92526443e5..6f5a842475e 100644 --- a/stablehlo/dialect/ChloOps.cpp +++ b/stablehlo/dialect/ChloOps.cpp @@ -60,6 +60,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AtanhOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(BesselI1eOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ConjOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CoshOp) +INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CrossReplicaSumOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DigammaOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfcOp) diff --git a/stablehlo/dialect/ChloOps.td b/stablehlo/dialect/ChloOps.td index 5947871c853..82c26f10f0a 100644 --- a/stablehlo/dialect/ChloOps.td +++ b/stablehlo/dialect/ChloOps.td @@ -945,4 +945,27 @@ def CHLO_DynamicReshapeOp: CHLO_Op<"dynamic_reshape", [NoSideEffect, let results = (outs HLO_Tensor:$result); } +def CHLO_CrossReplicaSumOp : CHLO_Op<"cross_replica_sum", + [NoSideEffect, HLO_CompatibleOperandsAndResultType]> { + let summary = "Sums input across replicated instances."; + let description = [{ + For each of the replica groups, operands of the group devices are summed + so that each device has the sum. + + For example, suppose there are 8 TPU devices: `[A, B, C, D, E, F, G, H]`. + Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0, + and `B, D, F, H` as group 1. Thus we get the outputs: + `[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`. + + See https://www.tensorflow.org/xla/operation_semantics#crossreplicasum. + }]; + + let arguments = (ins + HLO_Tensor:$operand, + I64ElementsAttr:$replica_groups + ); + + let results = (outs HLO_Tensor); +} + #endif // STABLEHLO_DIALECT_CHLO_OPS