Skip to content

Commit

Permalink
Extend Attribute Scalars by their unsigned types
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 632222737
  • Loading branch information
Paweł Paruzel authored and copybara-github committed May 9, 2024
1 parent 7c5b61d commit d60579f
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 18 deletions.
16 changes: 10 additions & 6 deletions xla/ffi/call_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,17 +133,13 @@ struct CallFrame::Dictionary {
};

struct CallFrame::Array {
std::variant<std::vector<int8_t>, std::vector<int16_t>, std::vector<int32_t>,
std::vector<int64_t>, std::vector<float>,
std::vector<double>>
value; // XLA_FFI_Array::data
CallFrameBuilder::Array value; // XLA_FFI_Array::data

XLA_FFI_Array array = {XLA_FFI_Array_STRUCT_SIZE, nullptr};
};

struct CallFrame::Scalar {
std::variant<bool, int8_t, int16_t, int32_t, int64_t, float, double>
value; // XLA_FFI_Scalar::value
CallFrameBuilder::Scalar value; // XLA_FFI_Scalar::value

XLA_FFI_Scalar scalar = {XLA_FFI_Scalar_STRUCT_SIZE, nullptr};
};
Expand Down Expand Up @@ -372,6 +368,14 @@ static XLA_FFI_DataType GetDataType() {
return XLA_FFI_DataType_S32;
} else if constexpr (std::is_same_v<int64_t, T>) {
return XLA_FFI_DataType_S64;
} else if constexpr (std::is_same_v<uint8_t, T>) {
return XLA_FFI_DataType_U8;
} else if constexpr (std::is_same_v<uint16_t, T>) {
return XLA_FFI_DataType_U16;
} else if constexpr (std::is_same_v<uint32_t, T>) {
return XLA_FFI_DataType_U32;
} else if constexpr (std::is_same_v<uint64_t, T>) {
return XLA_FFI_DataType_U64;
} else if constexpr (std::is_same_v<float, T>) {
return XLA_FFI_DataType_F32;
} else if constexpr (std::is_same_v<double, T>) {
Expand Down
6 changes: 4 additions & 2 deletions xla/ffi/call_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@ class CallFrameBuilder {
CallFrameBuilder(CallFrameBuilder&&);
CallFrameBuilder& operator=(CallFrameBuilder&&);

using Scalar =
std::variant<bool, int8_t, int16_t, int32_t, int64_t, float, double>;
using Scalar = std::variant<bool, int8_t, int16_t, int32_t, int64_t, uint8_t,
uint16_t, uint32_t, uint64_t, float, double>;
using Array = std::variant<std::vector<int8_t>, std::vector<int16_t>,
std::vector<int32_t>, std::vector<int64_t>,
std::vector<uint8_t>, std::vector<uint16_t>,
std::vector<uint32_t>, std::vector<uint64_t>,
std::vector<float>, std::vector<double>>;

// Declare implementation detail structs for call frame builder storage.
Expand Down
55 changes: 45 additions & 10 deletions xla/service/cpu/runtime_handle_ffi_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,51 @@ absl::StatusOr<AttributesMap> BuildAttributesMap(mlir::DictionaryAttr dict) {
for (auto& kv : dict) {
std::string_view name = kv.getName().strref();

auto boolean = [&](mlir::BoolAttr boolean) {
attributes[name] = static_cast<bool>(boolean.getValue());
return absl::OkStatus();
};

auto integer = [&](mlir::IntegerAttr integer) {
switch (integer.getType().getIntOrFloatBitWidth()) {
case 32:
attributes[name] = static_cast<int32_t>(integer.getInt());
return absl::OkStatus();
case 64:
attributes[name] = static_cast<int64_t>(integer.getInt());
return absl::OkStatus();
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported integer attribute bit width for attribute: ", name));
const bool is_unsigned = integer.getType().isUnsignedInteger();
if (is_unsigned) {
switch (integer.getType().getIntOrFloatBitWidth()) {
case 8:
attributes[name] = static_cast<uint8_t>(integer.getUInt());
return absl::OkStatus();
case 16:
attributes[name] = static_cast<uint16_t>(integer.getUInt());
return absl::OkStatus();
case 32:
attributes[name] = static_cast<uint32_t>(integer.getUInt());
return absl::OkStatus();
case 64:
attributes[name] = static_cast<uint64_t>(integer.getUInt());
return absl::OkStatus();
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported integer attribute bit width for attribute: ",
name));
}
} else {
switch (integer.getType().getIntOrFloatBitWidth()) {
case 8:
attributes[name] = static_cast<int8_t>(integer.getInt());
return absl::OkStatus();
case 16:
attributes[name] = static_cast<int16_t>(integer.getInt());
return absl::OkStatus();
case 32:
attributes[name] = static_cast<int32_t>(integer.getInt());
return absl::OkStatus();
case 64:
attributes[name] = static_cast<int64_t>(integer.getInt());
return absl::OkStatus();
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported integer attribute bit width for attribute: ",
name));
}
}
};

Expand All @@ -87,6 +121,7 @@ absl::StatusOr<AttributesMap> BuildAttributesMap(mlir::DictionaryAttr dict) {

TF_RETURN_IF_ERROR(
llvm::TypeSwitch<mlir::Attribute, absl::Status>(kv.getValue())
.Case<mlir::BoolAttr>(boolean)
.Case<mlir::IntegerAttr>(integer)
.Case<mlir::FloatAttr>(fp)
.Case<mlir::StringAttr>(str)
Expand Down

0 comments on commit d60579f

Please sign in to comment.