diff --git a/xla/ffi/call_frame.cc b/xla/ffi/call_frame.cc index 4db2d01c736eb..867c7ccb35edf 100644 --- a/xla/ffi/call_frame.cc +++ b/xla/ffi/call_frame.cc @@ -133,17 +133,13 @@ struct CallFrame::Dictionary { }; struct CallFrame::Array { - std::variant, std::vector, std::vector, - std::vector, std::vector, - std::vector> - value; // XLA_FFI_Array::data + CallFrameBuilder::Array value; // XLA_FFI_Array::data XLA_FFI_Array array = {XLA_FFI_Array_STRUCT_SIZE, nullptr}; }; struct CallFrame::Scalar { - std::variant - value; // XLA_FFI_Scalar::value + CallFrameBuilder::Scalar value; // XLA_FFI_Scalar::value XLA_FFI_Scalar scalar = {XLA_FFI_Scalar_STRUCT_SIZE, nullptr}; }; @@ -372,6 +368,14 @@ static XLA_FFI_DataType GetDataType() { return XLA_FFI_DataType_S32; } else if constexpr (std::is_same_v) { return XLA_FFI_DataType_S64; + } else if constexpr (std::is_same_v) { + return XLA_FFI_DataType_U8; + } else if constexpr (std::is_same_v) { + return XLA_FFI_DataType_U16; + } else if constexpr (std::is_same_v) { + return XLA_FFI_DataType_U32; + } else if constexpr (std::is_same_v) { + return XLA_FFI_DataType_U64; } else if constexpr (std::is_same_v) { return XLA_FFI_DataType_F32; } else if constexpr (std::is_same_v) { diff --git a/xla/ffi/call_frame.h b/xla/ffi/call_frame.h index 9e8fb642c6331..9bfd31cd4326a 100644 --- a/xla/ffi/call_frame.h +++ b/xla/ffi/call_frame.h @@ -58,10 +58,12 @@ class CallFrameBuilder { CallFrameBuilder(CallFrameBuilder&&); CallFrameBuilder& operator=(CallFrameBuilder&&); - using Scalar = - std::variant; + using Scalar = std::variant; using Array = std::variant, std::vector, std::vector, std::vector, + std::vector, std::vector, + std::vector, std::vector, std::vector, std::vector>; // Declare implementation detail structs for call frame builder storage. diff --git a/xla/service/cpu/runtime_handle_ffi_call.cc b/xla/service/cpu/runtime_handle_ffi_call.cc index a962c998d3d82..d0e26cc5cfc23 100644 --- a/xla/service/cpu/runtime_handle_ffi_call.cc +++ b/xla/service/cpu/runtime_handle_ffi_call.cc @@ -55,17 +55,51 @@ absl::StatusOr BuildAttributesMap(mlir::DictionaryAttr dict) { for (auto& kv : dict) { std::string_view name = kv.getName().strref(); + auto boolean = [&](mlir::BoolAttr boolean) { + attributes[name] = static_cast(boolean.getValue()); + return absl::OkStatus(); + }; + auto integer = [&](mlir::IntegerAttr integer) { - switch (integer.getType().getIntOrFloatBitWidth()) { - case 32: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 64: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported integer attribute bit width for attribute: ", name)); + const bool is_unsigned = integer.getType().isUnsignedInteger(); + if (is_unsigned) { + switch (integer.getType().getIntOrFloatBitWidth()) { + case 8: + attributes[name] = static_cast(integer.getUInt()); + return absl::OkStatus(); + case 16: + attributes[name] = static_cast(integer.getUInt()); + return absl::OkStatus(); + case 32: + attributes[name] = static_cast(integer.getUInt()); + return absl::OkStatus(); + case 64: + attributes[name] = static_cast(integer.getUInt()); + return absl::OkStatus(); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported integer attribute bit width for attribute: ", + name)); + } + } else { + switch (integer.getType().getIntOrFloatBitWidth()) { + case 8: + attributes[name] = static_cast(integer.getInt()); + return absl::OkStatus(); + case 16: + attributes[name] = static_cast(integer.getInt()); + return absl::OkStatus(); + case 32: + attributes[name] = static_cast(integer.getInt()); + return absl::OkStatus(); + case 64: + attributes[name] = static_cast(integer.getInt()); + return absl::OkStatus(); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported integer attribute bit width for attribute: ", + name)); + } } }; @@ -87,6 +121,7 @@ absl::StatusOr BuildAttributesMap(mlir::DictionaryAttr dict) { TF_RETURN_IF_ERROR( llvm::TypeSwitch(kv.getValue()) + .Case(boolean) .Case(integer) .Case(fp) .Case(str)