Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move CrossReplicaSum from StableHLO to CHLO. #118

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions stablehlo/dialect/ChloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AtanhOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(BesselI1eOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ConjOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CoshOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CrossReplicaSumOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DigammaOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfcOp)
Expand Down
23 changes: 23 additions & 0 deletions stablehlo/dialect/ChloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -945,4 +945,27 @@ def CHLO_DynamicReshapeOp: CHLO_Op<"dynamic_reshape", [NoSideEffect,
let results = (outs HLO_Tensor:$result);
}

def CHLO_CrossReplicaSumOp : CHLO_Op<"cross_replica_sum",
[NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
let summary = "Sums input across replicated instances.";
let description = [{
For each of the replica groups, operands of the group devices are summed
so that each device has the sum.

For example, suppose there are 8 TPU devices: `[A, B, C, D, E, F, G, H]`.
Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0,
and `B, D, F, H` as group 1. Thus we get the outputs:
`[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`.

See https://www.tensorflow.org/xla/operation_semantics#crossreplicasum.
}];

let arguments = (ins
HLO_Tensor:$operand,
I64ElementsAttr:$replica_groups
);

let results = (outs HLO_Tensor);
}

#endif // STABLEHLO_DIALECT_CHLO_OPS