Skip to content

Commit

Permalink
Integrate StableHLO at openxla/stablehlo@271e8634
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 620069321
  • Loading branch information
ghpvnist authored and TensorFlow MLIR Team committed Mar 28, 2024
1 parent 1185e20 commit 89dcfe6
Show file tree
Hide file tree
Showing 8 changed files with 6 additions and 167 deletions.
4 changes: 2 additions & 2 deletions stablehlo/WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "7ac7d418ac2b16fd44789dcf48e2b5d73de3e715"
LLVM_COMMIT = "3cf169ca160eaf5464503fbd93d73ee1d8597936"

LLVM_SHA256 = "8b99a146881fbb2a2d8e812724550b2c88fed4403dfb4e133ee8b7107a6a9348"
LLVM_SHA256 = "b63cac687df1bc98e3eb0289f3be6824fcb1b106d0720b5c083417918d1029fd"

http_archive(
name = "llvm-raw",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ if [[ $# -ne 0 ]] ; then
fi

echo "Gathering changed files..."
mapfile -t CHANGED_FILES < <(git diff "$BASE_BRANCH" HEAD --name-only --diff-filter=d | grep '.*\.cpp$\|.*\.h$\|.*\.md$\|.*\.mlir$\|.*\.sh$\|.*\.td$\|.*\.txt$\|.*\.yml$\|.*\.yaml$')
mapfile -t CHANGED_FILES < <(git diff "$BASE_BRANCH" HEAD --name-only --diff-filter=d | grep -Ev '.*\.(bc|png|svg)$')
if (( ${#CHANGED_FILES[@]} == 0 )); then
echo "No files to check."
exit 0
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/examples/c++/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package(
)

cc_binary(
name = "example-add",
name = "example_add",
srcs = [
"ExampleAdd.cpp",
],
Expand All @@ -33,7 +33,7 @@ cc_binary(
)

cc_test(
name = "example-add-test",
name = "example_add_test",
srcs = [
"ExampleAdd.cpp",
],
Expand Down
35 changes: 0 additions & 35 deletions stablehlo/stablehlo/dialect/ChloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,41 +412,6 @@ LogicalResult BroadcastSelectOp::reifyReturnTypeShapes(
return success();
}

//===----------------------------------------------------------------------===//
// RankSpecializationClusterOp
//===----------------------------------------------------------------------===//

void RankSpecializationClusterOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor>& regions) {
// RankSpecializationClusterOp has unconditional control flows into the region
// and back to the parent, so return the correct RegionSuccessor purely based
// on the index being None or 0.
if (!point.isParent()) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
regions.push_back(RegionSuccessor(&getBody()));
}

LogicalResult RankSpecializationClusterOp::verify() {
Block* body = SingleBlock::getBody();
if (body->getArgumentTypes() != getOperandTypes())
return emitOpError() << "block argument types must match operand types";

// All operands of nested ops must be defined in the body or declared by the
// cluster.
for (Operation& nested : body->without_terminator()) {
if (!llvm::all_of(nested.getOpOperands(), [&](OpOperand& operand) {
Operation* def = operand.get().getDefiningOp();
if (def != nullptr && def->getBlock() == body) return true;
return llvm::is_contained(body->getArguments(), operand.get());
}))
return emitOpError() << "nested ops must not depend on implicit operands";
}

return success();
}

//===----------------------------------------------------------------------===//
// TopKOp
//===----------------------------------------------------------------------===//
Expand Down
53 changes: 0 additions & 53 deletions stablehlo/stablehlo/dialect/ChloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -882,59 +882,6 @@ def CHLO_MinimumBroadcastShapesOp :
let hasVerifier = 1;
}

def CHLO_RankSpecializationClusterOp
: CHLO_Op<"rank_specialization_cluster", [
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"RankSpecializationClusterYieldOp">,
RecursiveMemoryEffects]> {

let summary = "Cluster of operations that will be rank-specialized together.";

let description = [{
Groups compatible element-wise operatons together so that they can be
rank-specialized together. The operation takes and yields a variadic number
of (unranked) tensor operands. Its body region holds one block with one
block argument per input tensor of the same type. All operations in this
block must only operate on these block arguments. Results are returned
through the `rank_specialization_cluster_yield` operation.

Example:

```
%0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>):
%1 = chlo.broadcast_multiply %arg0_, %arg1_
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%2 = chlo.broadcast_add %1, %arg2_
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
"chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> ()
}) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
```
}];

let arguments = (ins Variadic<HLO_AnyTensor>);
let results = (outs Variadic<HLO_AnyTensor>);
let regions = (region SizedRegion<1>:$body);

let hasVerifier = 1;
}

def CHLO_RankSpecializationClusterYieldOp
: CHLO_Op<"rank_specialization_cluster_yield", [Pure,
ReturnLike, Terminator, HasParent<"RankSpecializationClusterOp">]> {

let summary = "Yield operation for `rank_specialization_cluster`";
let description = [{
This operation yields the results from within the
`chlo.rank_specialization_cluster` operation's region. The operation takes
an arbitrary number of operands and produces no results. The operand number
and types must match the number and types of the parent
`rank_specialization_cluster` operation's results.
}];

let arguments = (ins Variadic<HLO_AnyTensor>:$results);
}

def CHLO_DynamicReshapeOp: CHLO_Op<"dynamic_reshape", [Pure,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface>]> {
let summary = "Reshape a tensor to a given, possibly dynamic, shape.";
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/stablehlo/dialect/Version.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Version {
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current VHLO dialect version.
static Version getCurrentVersion() { return Version(0, 19, 3); }
static Version getCurrentVersion() { return Version(0, 19, 4); }

/// Return a Version representing the minimum supported VHLO dialect version.
static Version getMinimumVersion() { return Version(0, 9, 0); }
Expand Down
55 changes: 0 additions & 55 deletions stablehlo/stablehlo/tests/ops_chlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -100,61 +100,6 @@ func.func @minimum_broadcast_shapes_one_operand(%arg: tensor<?xindex>) {

// -----

func.func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
%arg2 : tensor<*xf32>) -> tensor<*xf32> {
%0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>):
%1 = chlo.broadcast_multiply %arg0_, %arg1_
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%2 = chlo.broadcast_add %1, %arg2_
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
"chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> ()
}) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// -----

func.func @rank_specialization_cluster(%arg0 : tensor<*xf32>,
%arg1 : tensor<*xf32>) -> tensor<*xf32> {
// expected-error @+1{{source has 2 operands, but target successor needs 1}}
%0 = "chlo.rank_specialization_cluster"(%arg0, %arg1) ({
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>):
"chlo.rank_specialization_cluster_yield"(%arg0_, %arg1_)
: (tensor<*xf32>, tensor<*xf32>) -> ()
}) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// -----

func.func @rank_specialization_cluster(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
// expected-error @+1{{block argument types must match operand types}}
%0 = "chlo.rank_specialization_cluster"(%arg0) ({
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>):
"chlo.rank_specialization_cluster_yield"(%arg0_) : (tensor<*xf32>) -> ()
}) : (tensor<*xf32>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// -----

func.func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
%arg2 : tensor<*xf32>) -> tensor<*xf32> {
// expected-error @+1{{nested ops must not depend on implicit operands}}
%0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>):
%1 = chlo.broadcast_multiply %arg0_, %arg1_
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%2 = chlo.broadcast_add %1, %arg2
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
"chlo.rank_specialization_cluster_yield"(%2) : (tensor<*xf32>) -> ()
}) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// -----

func.func @top_k(%arg0 : tensor<f32>) {
// expected-error @+2 {{failed to infer returned types}}
// @expected-error @+1{{operand's rank must be at least 1}}
Expand Down
18 changes: 0 additions & 18 deletions stablehlo/stablehlo/tests/ops_chlo_roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -417,24 +417,6 @@ func.func @chlo_reshape_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<2xi32>) -> t
func.return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: func @chlo_rank_specialization_cluster
// CHECK-SAME: %[[A0:.*0]]: tensor<*xf32>,
// CHECK-SAME: %[[A1:.*1]]: tensor<*xf32>,
// CHECK-SAME: %[[A2:.*2]]: tensor<*xf32>)
// CHECK-NEXT: %[[T:.*]] = "chlo.rank_specialization_cluster"(%[[A0]], %[[A1]], %[[A2]])
// CHECK: ^bb0(%[[A3:.*]]: tensor<*xf32>, %[[A4:.*]]: tensor<*xf32>, %[[A5:.*]]: tensor<*xf32>):
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[A3]]) : (tensor<*xf32>) -> ()
// CHECK: }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[T]] : tensor<*xf32>
func.func @chlo_rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
%arg2 : tensor<*xf32>) -> tensor<*xf32> {
%0 = "chlo.rank_specialization_cluster"(%arg0, %arg1, %arg2) ({
^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>, %arg2_ : tensor<*xf32>):
"chlo.rank_specialization_cluster_yield"(%arg0_) : (tensor<*xf32>) -> ()
}) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @chlo_erf_inv
// CHECK-SAME: %[[A0:.*0]]: tensor<16x16xf32>)
// CHECK: chlo.erf_inv %[[A0]] : tensor<16x16xf32> -> tensor<16x16xf32>
Expand Down

0 comments on commit 89dcfe6

Please sign in to comment.