From 873d09720f83cbbebf2a2a381c09be8fa0934b36 Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Mon, 6 May 2024 17:24:55 -0700 Subject: [PATCH] Add an optional MemoryKind parameter to IFRT IR shardings. PiperOrigin-RevId: 631235085 --- xla/python/ifrt/ir/BUILD | 1 + xla/python/ifrt/ir/ifrt_dialect.cc | 15 +++++++++++-- xla/python/ifrt/ir/ifrt_dialect.td | 22 +++++++++++++++++-- xla/python/ifrt/ir/ifrt_interfaces.h | 5 +++-- xla/python/ifrt/ir/ifrt_interfaces.td | 6 +++++ xla/python/ifrt/ir/tests/BUILD | 4 ++-- .../tests/ifrt_verify_sharding_specified.mlir | 2 +- 7 files changed, 46 insertions(+), 9 deletions(-) diff --git a/xla/python/ifrt/ir/BUILD b/xla/python/ifrt/ir/BUILD index be4c8acf1bb9b..3aca84642799a 100644 --- a/xla/python/ifrt/ir/BUILD +++ b/xla/python/ifrt/ir/BUILD @@ -141,6 +141,7 @@ cc_library( ":ifrt_interfaces_inc_gen", ":ifrt_ops_inc_gen", ":sharding_param", + "//xla/python/ifrt", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:CallOpInterfaces", # buildcleaner: keep diff --git a/xla/python/ifrt/ir/ifrt_dialect.cc b/xla/python/ifrt/ir/ifrt_dialect.cc index d4314487b252a..3d831de2d766e 100644 --- a/xla/python/ifrt/ir/ifrt_dialect.cc +++ b/xla/python/ifrt/ir/ifrt_dialect.cc @@ -39,10 +39,11 @@ limitations under the License. #include "xla/python/ifrt/ir/constants.h" #include "xla/python/ifrt/ir/ifrt_interfaces.h" #include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/ir/sharding_param.h" +#include "xla/python/ifrt/memory.h" // Generated definitions. 
#include "xla/python/ifrt/ir/ifrt_dialect.cc.inc" -#include "xla/python/ifrt/ir/sharding_param.h" #define GET_TYPEDEF_CLASSES #include "xla/python/ifrt/ir/ifrt_types.cc.inc" #define GET_ATTRDEF_CLASSES @@ -127,7 +128,7 @@ mlir::LogicalResult IfrtDialect::verifyRegionArgAttribute( mlir::LogicalResult IfrtShardingParamAttr::verify( llvm::function_ref<mlir::InFlightDiagnostic()> emitError, - ShardingParam sharding_param) { + ShardingParam sharding_param, mlir::StringAttr memory_kind) { return sharding_param.verify(emitError); } @@ -154,6 +155,12 @@ int IfrtShardingParamAttr::NumDevices() const { return getSharding().NumDevices(); }; +xla::ifrt::MemoryKind IfrtShardingParamAttr::MemoryKind() const { + return getMemoryKind() == nullptr + ? xla::ifrt::MemoryKind() + : xla::ifrt::MemoryKind(getMemoryKind().str()); +}; + //===----------------------------------------------------------------------===// // IfrtUnspecifiedShardingAttr //===----------------------------------------------------------------------===// @@ -185,6 +192,10 @@ IfrtUnspecifiedShardingAttr::LocalShapeFromGlobalShape( int IfrtUnspecifiedShardingAttr::NumDevices() const { return 0; } +xla::ifrt::MemoryKind IfrtUnspecifiedShardingAttr::MemoryKind() const { + return xla::ifrt::MemoryKind(); +} + //===----------------------------------------------------------------------===// // IfrtArrayType //===----------------------------------------------------------------------===// diff --git a/xla/python/ifrt/ir/ifrt_dialect.td b/xla/python/ifrt/ir/ifrt_dialect.td index 044668aa91511..34395a1f26791 100644 --- a/xla/python/ifrt/ir/ifrt_dialect.td +++ b/xla/python/ifrt/ir/ifrt_dialect.td @@ -74,8 +74,26 @@ def Ifrt_ShardingParamAttr : AttrDef:$memory_kind + ); + let assemblyFormat = [{ + `<` $sharding (`,` `memory_kind` `=` $memory_kind^)? 
`>` + }]; + + let builders = [ + AttrBuilder<(ins "::xla::ifrt::ShardingParam":$sharding), [{ + return $_get($_ctxt, sharding, /*memory_kind=*/nullptr); + }]>, + AttrBuilder<(ins + "::xla::ifrt::ShardingParam":$sharding, + "::mlir::StringRef":$memory_kind), [{ + return $_get($_ctxt, + sharding, + ::mlir::StringAttr::get($_ctxt, memory_kind)); + }]> + ]; let genVerifyDecl = 1; } diff --git a/xla/python/ifrt/ir/ifrt_interfaces.h b/xla/python/ifrt/ir/ifrt_interfaces.h index b496f4b57d3be..77d754bbb4a39 100644 --- a/xla/python/ifrt/ir/ifrt_interfaces.h +++ b/xla/python/ifrt/ir/ifrt_interfaces.h @@ -1,5 +1,3 @@ -#include "xla/python/ifrt/ir/constants.h" -#include "xla/python/ifrt/ir/ifrt_dialect.h" /* Copyright 2023 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,11 +16,14 @@ limitations under the License. #ifndef XLA_PYTHON_IFRT_IR_IFRT_INTERFACES_H_ #define XLA_PYTHON_IFRT_IR_IFRT_INTERFACES_H_ +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "xla/python/ifrt/ir/constants.h" #include "xla/python/ifrt/ir/sharding_param.h" +#include "xla/python/ifrt/memory.h" namespace mlir { namespace OpTrait { diff --git a/xla/python/ifrt/ir/ifrt_interfaces.td b/xla/python/ifrt/ir/ifrt_interfaces.td index f1a94a517b91b..f2dfe981a8995 100644 --- a/xla/python/ifrt/ir/ifrt_interfaces.td +++ b/xla/python/ifrt/ir/ifrt_interfaces.td @@ -108,6 +108,12 @@ def Ifrt_ShardingAttrInterface : Ifrt_AttrInterface<"IfrtShardingAttrInterface"> /*retTy=*/"int", /*methodName=*/"NumDevices", /*args=*/(ins) + >, + InterfaceMethod< + /*desc=*/"Returns the memory kind.", + /*retTy=*/"::xla::ifrt::MemoryKind", + /*methodName=*/"MemoryKind", + /*args=*/(ins) > ]; } diff --git a/xla/python/ifrt/ir/tests/BUILD 
b/xla/python/ifrt/ir/tests/BUILD index d954e0117964d..44ca08ac9081e 100644 --- a/xla/python/ifrt/ir/tests/BUILD +++ b/xla/python/ifrt/ir/tests/BUILD @@ -1,5 +1,5 @@ load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") -load("//xla:xla.bzl", "xla_cc_test") +load("//xla:xla.bzl", "xla_cc_binary", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -35,7 +35,7 @@ lit_test_suite( ], ) -cc_binary( +xla_cc_binary( name = "ifrt-opt", srcs = ["ifrt-opt.cc"], deps = [ diff --git a/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir b/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir index 6be0131833826..e79d8816c7b93 100644 --- a/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_verify_sharding_specified.mlir @@ -1,7 +1,7 @@ // RUN: ifrt-opt %s -ifrt-verify-sharding-specified -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: @good_arrays -#sharding = #ifrt.sharding_param<2 to [0] on 2> +#sharding = #ifrt.sharding_param<2 to [0] on 2, memory_kind = "device"> module @good_arrays { func.func @main(%arg0: !ifrt.array, #sharding, [0,1]>) -> !ifrt.array, #sharding, [2,3]>