Skip to content

Commit

Permalink
Add an optional MemoryKind parameter to IFRT IR shardings.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631235085
  • Loading branch information
ICGog authored and copybara-github committed May 7, 2024
1 parent f0eac55 commit 873d097
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 9 deletions.
1 change: 1 addition & 0 deletions xla/python/ifrt/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ cc_library(
":ifrt_interfaces_inc_gen",
":ifrt_ops_inc_gen",
":sharding_param",
"//xla/python/ifrt",
"@com_google_absl//absl/status:statusor",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CallOpInterfaces", # buildcleaner: keep
Expand Down
15 changes: 13 additions & 2 deletions xla/python/ifrt/ir/ifrt_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ limitations under the License.
#include "xla/python/ifrt/ir/constants.h"
#include "xla/python/ifrt/ir/ifrt_interfaces.h"
#include "xla/python/ifrt/ir/ifrt_ops.h"
#include "xla/python/ifrt/ir/sharding_param.h"
#include "xla/python/ifrt/memory.h"

// Generated definitions.
#include "xla/python/ifrt/ir/ifrt_dialect.cc.inc"
#include "xla/python/ifrt/ir/sharding_param.h"
#define GET_TYPEDEF_CLASSES
#include "xla/python/ifrt/ir/ifrt_types.cc.inc"
#define GET_ATTRDEF_CLASSES
Expand Down Expand Up @@ -127,7 +128,7 @@ mlir::LogicalResult IfrtDialect::verifyRegionArgAttribute(

mlir::LogicalResult IfrtShardingParamAttr::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
ShardingParam sharding_param) {
ShardingParam sharding_param, mlir::StringAttr memory_kind) {
return sharding_param.verify(emitError);
}

Expand All @@ -154,6 +155,12 @@ int IfrtShardingParamAttr::NumDevices() const {
return getSharding().NumDevices();
};

xla::ifrt::MemoryKind IfrtShardingParamAttr::MemoryKind() const {
return getMemoryKind() == nullptr
? xla::ifrt::MemoryKind()
: xla::ifrt::MemoryKind(getMemoryKind().str());
};

//===----------------------------------------------------------------------===//
// IfrtUnspecifiedShardingAttr
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -185,6 +192,10 @@ IfrtUnspecifiedShardingAttr::LocalShapeFromGlobalShape(

int IfrtUnspecifiedShardingAttr::NumDevices() const { return 0; }

xla::ifrt::MemoryKind IfrtUnspecifiedShardingAttr::MemoryKind() const {
return xla::ifrt::MemoryKind();
}

//===----------------------------------------------------------------------===//
// IfrtArrayType
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 20 additions & 2 deletions xla/python/ifrt/ir/ifrt_dialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,26 @@ def Ifrt_ShardingParamAttr : AttrDef<Ifrt_Dialect, "IfrtShardingParam", [
let mnemonic = "sharding_param";
let summary = "ShardingParam as an attribute.";

let parameters = (ins Ifrt_ShardingParameter:$sharding);
let assemblyFormat = "`<` $sharding `>`";
let parameters = (ins
Ifrt_ShardingParameter:$sharding,
OptionalParameter<"::mlir::StringAttr">:$memory_kind
);
let assemblyFormat = [{
`<` $sharding (`,` `memory_kind` `=` $memory_kind^)? `>`
}];

let builders = [
AttrBuilder<(ins "::xla::ifrt::ShardingParam":$sharding), [{
return $_get($_ctxt, sharding, /*memory_kind=*/nullptr);
}]>,
AttrBuilder<(ins
"::xla::ifrt::ShardingParam":$sharding,
"::mlir::StringRef":$memory_kind), [{
return $_get($_ctxt,
sharding,
::mlir::StringAttr::get($_ctxt, memory_kind));
}]>
];

let genVerifyDecl = 1;
}
Expand Down
5 changes: 3 additions & 2 deletions xla/python/ifrt/ir/ifrt_interfaces.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#include "xla/python/ifrt/ir/constants.h"
#include "xla/python/ifrt/ir/ifrt_dialect.h"
/* Copyright 2023 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -18,11 +16,14 @@ limitations under the License.
#ifndef XLA_PYTHON_IFRT_IR_IFRT_INTERFACES_H_
#define XLA_PYTHON_IFRT_IR_IFRT_INTERFACES_H_

#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "xla/python/ifrt/ir/constants.h"
#include "xla/python/ifrt/ir/sharding_param.h"
#include "xla/python/ifrt/memory.h"

namespace mlir {
namespace OpTrait {
Expand Down
6 changes: 6 additions & 0 deletions xla/python/ifrt/ir/ifrt_interfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def Ifrt_ShardingAttrInterface : Ifrt_AttrInterface<"IfrtShardingAttrInterface">
/*retTy=*/"int",
/*methodName=*/"NumDevices",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/"Returns the memory kind.",
/*retTy=*/"::xla::ifrt::MemoryKind",
/*methodName=*/"MemoryKind",
/*args=*/(ins)
>
];
}
Expand Down
4 changes: 2 additions & 2 deletions xla/python/ifrt/ir/tests/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
load("//xla:lit.bzl", "enforce_glob", "lit_test_suite")
load("//xla:xla.bzl", "xla_cc_test")
load("//xla:xla.bzl", "xla_cc_binary", "xla_cc_test")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
Expand Down Expand Up @@ -35,7 +35,7 @@ lit_test_suite(
],
)

cc_binary(
xla_cc_binary(
name = "ifrt-opt",
srcs = ["ifrt-opt.cc"],
deps = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: ifrt-opt %s -ifrt-verify-sharding-specified -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @good_arrays
#sharding = #ifrt.sharding_param<2 to [0] on 2>
#sharding = #ifrt.sharding_param<2 to [0] on 2, memory_kind = "device">
module @good_arrays {
func.func @main(%arg0: !ifrt.array<tensor<2xi32>, #sharding, [0,1]>)
-> !ifrt.array<tensor<2xi32>, #sharding, [2,3]>
Expand Down

0 comments on commit 873d097

Please sign in to comment.