diff --git a/meson.build b/meson.build
index ef6389bac5..7ff7393856 100644
--- a/meson.build
+++ b/meson.build
@@ -620,6 +620,22 @@ if get_option('build_backends')
endif
+ ## ~~~~~~~~
+ ## XLA
+ ## ~~~~~~~~
+ if get_option('xla')
+ files += [
+ 'src/neural/xla/hlo_builder.cc',
+ 'src/neural/xla/network_xla.cc',
+ 'src/neural/xla/onnx2hlo.cc',
+ 'src/neural/xla/print_hlo.cc',
+ 'src/neural/xla/pjrt.cc',
+ 'src/neural/xla/xla_runner.cc',
+ ]
+ files += gen_proto_src.process('src/neural/xla/hlo.proto',
+ preserve_path_from : meson.current_source_dir() + '/src/')
+ deps += cc.find_library('dl', required: false)
+ endif
endif # if get_option('build_backends')
diff --git a/meson_options.txt b/meson_options.txt
index 3a187cac5b..0606fb38d2 100644
--- a/meson_options.txt
+++ b/meson_options.txt
@@ -187,3 +187,8 @@ option('onnx_include',
type: 'string',
value: '',
description: 'Paths to ONNX runtime includes')
+
+option('xla',
+ type: 'boolean',
+ value: false,
+ description: 'Enable XLA backend')
\ No newline at end of file
diff --git a/scripts/compile_proto.py b/scripts/compile_proto.py
index 9163683a4e..cb7d0450b2 100755
--- a/scripts/compile_proto.py
+++ b/scripts/compile_proto.py
@@ -64,6 +64,7 @@
'package',
'message',
'optional',
+ 'required',
'repeated',
'enum',
] + list(TYPES.keys())
diff --git a/src/neural/xla/hlo.proto b/src/neural/xla/hlo.proto
new file mode 100644
index 0000000000..a374b25837
--- /dev/null
+++ b/src/neural/xla/hlo.proto
@@ -0,0 +1,412 @@
+/*
+ This file is part of Leela Chess Zero.
+ Copyright (C) 2024 The LCZero Authors
+
+ Leela Chess is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Chess is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Chess. If not, see <https://www.gnu.org/licenses/>.
+
+ Additional permission under GNU GPL version 3 section 7
+
+ If you modify this Program, or any covered work, by linking or
+ combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
+ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
+ modified version of those libraries), containing parts covered by the
+ terms of the respective license agreement, the licensors of this
+ Program grant you additional permission to convey the resulting work.
+*/
+
+syntax = "proto2";
+
+package pblczero;
+
+message XlaLayoutProto {
+ // Sequence of dimension numbers, from minor (fastest varying index) to major
+ // (slowest varying index). This field is required.
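+ // For example, a row-major f32[2,3] array has minor_to_major = [1, 0]:
+ // dimension 1 is the fastest varying in memory.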
+ repeated int64 minor_to_major = 1;
+}
+
+message XlaShapeProto {
+ enum Type {
+ PRIMITIVE_TYPE_INVALID = 0;
+
+ // Predicates are two-state booleans.
+ PRED = 1;
+
+ // Signed integral values of fixed width.
+ S4 = 21;
+ S8 = 2;
+ S16 = 3;
+ S32 = 4;
+ S64 = 5;
+
+ // Unsigned integral values of fixed width.
+ U4 = 22;
+ U8 = 6;
+ U16 = 7;
+ U32 = 8;
+ U64 = 9;
+
+ // Floating-point values of fixed width.
+ //
+ // Note: if f16s are not natively supported on the device, they will be
+ // converted to f16 from f32 at arbitrary points in the computation.
+ F16 = 10;
+ F32 = 11;
+
+ // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
+ // floating-point format, but uses 1 bit for the sign, 8 bits for the
+ // exponent and 7 bits for the mantissa.
+ BF16 = 16;
+
+ F64 = 12;
+
+ // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2209.05433
+ //
+ // F8E5M2 has 5 exponent bits and 2 mantissa bits, and is similar to the
+ // existing IEEE types.
+ //
+ // F8E4M3FN has 4 exponent bits and 3 mantissa bits. The "FN" means only
+ // Finite and NaN values are supported. Unlike IEEE types, infinities are
+ // not supported. NaN is represented when the exponent and mantissa bits
+ // are all 1s. All other values are finite.
+ //
+ // F8E4M3B11FNUZ has 4 exponent bits and 3 mantissa bits and a bias of 11.
+ // The "FNUZ" means only Finite and NaN values are supported; zero is
+ // unsigned. Unlike IEEE types, infinities are not supported. NaN is
+ // represented when the exponent and mantissa bits are all 0s with a sign
+ // bit of 1. All other values are finite.
+
+ F8E5M2 = 19;
+ F8E4M3FN = 20;
+ F8E4M3B11FNUZ = 23;
+
+ // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2206.02915
+ //
+ // F8E5M2FNUZ has 5 exponent bits and 2 mantissa bits.
+ // F8E4M3FNUZ has 4 exponent bits and 3 mantissa bits.
+ //
+ // The "FNUZ" means only Finite and NaN values are supported; zero is
+ // unsigned. Unlike IEEE types, infinities are not supported. NaN is
+ // represented when the exponent and mantissa bits are all 0s with a sign
+ // bit of 1. All other values are finite.
+ //
+ // These differences mean there's an additional exponent value available. To
+ // keep the same dynamic range as an IEEE-like FP8 type, the exponent is
+ // biased one more than would be expected given the number of exponent bits
+ // (8 for Float8E4M3FNUZ and 16 for Float8E5M2FNUZ).
+ F8E5M2FNUZ = 24;
+ F8E4M3FNUZ = 25;
+
+ // Complex values of fixed width.
+ C64 = 15; // Paired F32 (real, imag), as in std::complex<float>.
+ C128 = 18; // Paired F64 (real, imag), as in std::complex<double>.
+
+ // A tuple is a polymorphic sequence; e.g. a shape that holds different
+ // sub-shapes. They are used for things like returning multiple values from
+ // a computation; e.g. a computation that returns weights and biases may
+ // have a signature that results in a tuple like (f32[784x2000], f32[2000])
+ //
+ // If a shape proto has the tuple element type, it may not have any entries
+ // in the dimensions field.
+ TUPLE = 13;
+
+ // An opaque type used for passing context-specific data to a custom
+ // operation. Shapes of this primitive type will have empty dimensions and
+ // tuple_shapes fields.
+ //
+ // (OPAQUE would be a better name for this identifier, but that conflicts
+ // with a macro defined in windows.h.)
+ OPAQUE_TYPE = 14;
+
+ // A token type threaded between side-effecting operations. Shapes of this
+ // primitive type will have empty dimensions and tuple_shapes fields.
+ TOKEN = 17;
+ }
+
+ // The element type for this shape.
+ required Type element_type = 2;
+
+ // The size (number of elements) for each dimension, or an upper bound on the
+ // size if the dimension is dynamic. In XLA, dimensions are numbered from 0
+ // to N-1 for an N-dimensional array. The first element of 'dimensions' is the
+ // size of dimension 0, the second element is the size of dimension 1, and so
+ // forth. Empty list indicates a scalar.
+ //
+ // If the respective element in 'is_dimension_dynamic' is true then the value
+ // in this field represents an upper bound on the size of the dimension.
+ repeated int64 dimensions = 3;
+
+ // For tuples only, the shapes of constituent shapes in the tuple sequence.
+ repeated XlaShapeProto tuple_shapes = 4;
+
+ // The layout used to back this shape.
+ required XlaLayoutProto layout = 5;
+
+ // For arrays, this indicates whether or not each dimension is
+ // dynamically-sized. The number of elements in this repeated field should be
+ // zero (indicating that no dimensions are dynamic) or equal to the number of
+ // elements in the 'dimensions' field.
+ repeated bool is_dynamic_dimension = 6;
+}
+
+// Shape of the parameters and output of a computation (like a traditional
+// function signature).
+message XlaProgramShapeProto {
+ repeated XlaShapeProto parameters = 1;
+ required XlaShapeProto result = 2;
+ repeated string parameter_names = 3;
+}
+
+// Symbolization metadata for HLO Instructions.
+//
+// This metadata is used for debugging XLA code generation, as well as
+// performance profiling of XLA-generated executables.
+message XlaOpMetadata {
+ // The framework op name that generated this XLA op.
+ //
+ // Frameworks that build on top of XLA should mirror the names of their ops
+ // back to users by specifying the op_type. In this way, even if the
+ // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
+ // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
+ // multiple ops, then each op should have the op_type be "SoftMax".)
+ optional string op_type = 1;
+ // The user-specified name of the op.
+ //
+ // This name is often unique within a computation. Note: some frameworks
+ // add auto-generated names if the user does not provide one.
+ optional string op_name = 2;
+ // Indicate a file and line that this op is associated to in a user's program.
+ //
+ // e.g. it could be the file and line of user code that generated the op.
+ optional string source_file = 3;
+ optional int32 source_line = 4;
+}
+
+message XlaLiteralProto {
+ required XlaShapeProto shape = 1;
+ repeated bool preds = 2;
+ optional bytes s4s = 21;
+ optional bytes u4s = 22;
+ optional bytes s8s = 15;
+ optional bytes u8s = 3;
+ repeated int32 s32s = 4;
+ repeated int64 s64s = 5;
+ repeated uint32 u32s = 6;
+ repeated uint64 u64s = 7;
+ repeated float f32s = 8;
+ repeated double f64s = 9;
+ repeated float c64s = 12; // Stored as interleaved real, imag floats.
+ repeated double c128s = 18; // Stored as interleaved real, imag doubles.
+ repeated XlaLiteralProto tuple_literals = 10;
+ // The F16s, BF16s, U16s and S16s are encoded in little endian byte order
+ optional bytes f16s = 11;
+ optional bytes bf16s = 13;
+ optional bytes u16s = 16;
+ optional bytes s16s = 17;
+ optional bytes f8e5m2s = 19;
+ optional bytes f8e4m3fns = 20;
+ optional bytes f8e4m3b11fnuzs = 23;
+ optional bytes f8e5m2fnuzs = 24;
+ optional bytes f8e4m3fnuzs = 25;
+ repeated int64 sparse_indices = 14;
+ // Next = 26
+}
+
+message XlaWindowDimension {
+ optional int64 size = 1;
+ optional int64 stride = 2;
+ optional int64 padding_low = 3;
+ optional int64 padding_high = 4;
+ optional int64 window_dilation = 5;
+ optional int64 base_dilation = 6;
+ optional bool window_reversal = 7;
+}
+
+message XlaWindow { repeated XlaWindowDimension dimensions = 1; }
+
+message XlaConvolutionDimensionNumbers {
+ optional int64 input_batch_dimension = 7;
+ optional int64 input_feature_dimension = 8;
+ repeated int64 input_spatial_dimensions = 11;
+ optional int64 kernel_input_feature_dimension = 3;
+ optional int64 kernel_output_feature_dimension = 4;
+ repeated int64 kernel_spatial_dimensions = 6;
+ optional int64 output_batch_dimension = 9;
+ optional int64 output_feature_dimension = 10;
+ repeated int64 output_spatial_dimensions = 12;
+}
+
+message XlaDotDimensionNumbers {
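+ // For a plain 2-D matrix multiplication C[m,n] = A[m,k] * B[k,n], the
+ // contracting dimensions are lhs=[1] and rhs=[0], with no batch dimensions
+ // (this is what OpMatMul in onnx2hlo.cc below emits).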
+ repeated int64 lhs_contracting_dimensions = 1;
+ repeated int64 rhs_contracting_dimensions = 2;
+ repeated int64 lhs_batch_dimensions = 3;
+ repeated int64 rhs_batch_dimensions = 4;
+}
+
+message HloInstructionProto {
+ required string name = 1;
+ required string opcode = 2;
+ required XlaShapeProto shape = 3;
+
+ optional XlaOpMetadata metadata = 7;
+
+ // Literal, only present for kConstant.
+ optional XlaLiteralProto literal = 8;
+
+ // Parameter number is only present for kParameter.
+ optional int64 parameter_number = 9;
+
+ // Index for kGetTupleElement.
+ optional int64 tuple_index = 13;
+
+ // Describes the window in a windowed operation such as convolution.
+ optional XlaWindow window = 15;
+
+ // Describes the dimension numbers used for a convolution.
+ optional XlaConvolutionDimensionNumbers convolution_dimension_numbers = 16;
+
+ optional XlaDotDimensionNumbers dot_dimension_numbers = 30;
+
+ // Dimensions present for some operations that require reshaping or
+ // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
+ repeated int64 dimensions = 14;
+
+ // The id of this instruction.
+ required int64 id = 35;
+
+ repeated int64 operand_ids = 36;
+ repeated int64 called_computation_ids = 38;
+}
+
+message HloComputationProto {
+ required string name = 1;
+
+ // The array of instructions is always in a valid dependency order, where
+ // operands appear before their users.
+ repeated HloInstructionProto instructions = 2;
+ required XlaProgramShapeProto program_shape = 4;
+
+ // The id of this computation.
+ required int64 id = 5;
+
+ // The id of the root of the computation.
+ required int64 root_id = 6;
+}
+
+message HloModuleProto {
+ required string name = 1;
+ required string entry_computation_name = 2;
+ required int64 entry_computation_id = 6;
+
+ // The array of computations is always in a valid dependency order, where
+ // callees appear before their callers.
+ repeated HloComputationProto computations = 3;
+
+ // The host program shape (with layout) of the entry computation.
+ required XlaProgramShapeProto host_program_shape = 4;
+
+ // The id of this module.
+ required int64 id = 5;
+}
+
+message OptionOverrideProto {
+ optional string string_field = 1;
+ optional bool bool_field = 2;
+ optional int64 int_field = 3;
+ optional double double_field = 4;
+}
+
+message CompileEnvOptionProto {
+ required string key = 1;
+ required OptionOverrideProto value = 2;
+}
+
+message ExecutableBuildOptionsProto {
+ // If set, this is the device to build the computation for. Valid
+ // device_ordinal values are: 0 to # of devices - 1. These values are
+ // identical to the device ordinal values used by StreamExecutor. The built
+ // executable will be executable on any device equivalent to the specified
+ // device as determined by Backend::devices_equivalent(). A value of -1
+ // indicates this option has not been set.
+ optional int64 device_ordinal = 1;
+
+ // If set, this specifies the layout of the result of the computation. If not
+ // set, the service will choose the layout of the result. A Shape is used to
+ // store the layout to accommodate tuple result shapes. A value of nullptr
+ // indicates the option has not been set.
+ optional XlaShapeProto result_layout = 2;
+
+ // The number of replicas of this computation that are to be executed.
+ // Defaults to 1.
+ optional int64 num_replicas = 4;
+
+ // The number of partitions in this computation. Defaults to 1.
+ optional int64 num_partitions = 5;
+
+ // Indicates whether to use SPMD (true) or MPMD (false) partitioning when
+ // num_partitions > 1 and XLA is requested to partition the input program.
+ optional bool use_spmd_partitioning = 6;
+
+ // Whether to automatically generate XLA shardings for SPMD partitioner.
+ optional bool use_auto_spmd_partitioning = 7;
+
+ // Whether HLOs should be deduplicated.
+ optional bool deduplicate_hlo = 8;
+
+ // Whether input and output buffers are aliased if the associated parameter is
+ // passed-through XLA modules without being changed.
+ optional bool alias_passthrough_params = 10;
+
+ // By default, XLA builds an executable by invoking standard compilation, i.e.
+ // running Compiler::Compile, or both Compiler::RunHloPasses and
+ // Compiler::RunBackend. When run_backend_only is set to true, XLA builds an
+ // executable by invoking only RunBackend and skip invoking RunHloPasses,
+ // which can be used to compile post-optimizations HLO modules.
+ optional bool run_backend_only = 11;
+
+ // Allows sharding propagation to propagate to the outputs. This changes the
+ // output shape of the computation (which is undesirable), but it can be used
+ // to allow to run partial compilation to determine what would be the output
+ // sharding of a computation if XLA would be allowed to propagate the sharding
+ // which can be used by higher level framework as a way to query intermediate
+ // sharding of operations when multiple computation would be chained and
+ // merged together.
+ // This is a vector of bool, because the user can control (if the output of
+ // the computation is a tuple) which elements of the tuple can have the
+ // sharding substituted and which don't. If only one boolean value is passed
+ // in the vector that's interpreted as the value to be applied for every
+ // single element of the output tuple. One value per element of the tuple
+ // means that each value is attached to one of the output elements.
+ repeated bool allow_spmd_sharding_propagation_to_output = 12;
+
+ // Opaque profile data for any feedback directed optimizations.
+ optional bytes fdo_profile = 14;
+
+ optional int64 device_memory_size = 15;
+
+ // Mesh shape in auto sharding options.
+ repeated int64 auto_spmd_partitioning_mesh_shape = 16;
+
+ // Mesh ids in auto sharding options.
+ repeated int64 auto_spmd_partitioning_mesh_ids = 17;
+}
+
+message CompileOptionsProto {
+ repeated XlaShapeProto argument_layouts = 1;
+ optional bool parameter_is_tupled_arguments = 2;
+ optional ExecutableBuildOptionsProto executable_build_options = 3;
+ optional bool compile_portable_executable = 4;
+ optional int64 profile_version = 5;
+ optional bytes serialized_multi_slice_config = 6;
+ repeated CompileEnvOptionProto env_options = 7;
+}
\ No newline at end of file
diff --git a/src/neural/xla/hlo_builder.cc b/src/neural/xla/hlo_builder.cc
new file mode 100644
index 0000000000..5178197863
--- /dev/null
+++ b/src/neural/xla/hlo_builder.cc
@@ -0,0 +1,312 @@
+/*
+ This file is part of Leela Chess Zero.
+ Copyright (C) 2024 The LCZero Authors
+
+ Leela Chess is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Chess is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Chess. If not, see <https://www.gnu.org/licenses/>.
+
+ Additional permission under GNU GPL version 3 section 7
+
+ If you modify this Program, or any covered work, by linking or
+ combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
+ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
+ modified version of those libraries), containing parts covered by the
+ terms of the respective license agreement, the licensors of this
+ Program grant you additional permission to convey the resulting work.
+*/
+
+#include "neural/xla/hlo_builder.h"
+
+#include <algorithm>
+#include <functional>
+#include <numeric>
+
+#include "utils/exception.h"
+#include "utils/logging.h"
+
+namespace lczero {
+
+// Creates an instruction and populates required fields of the
+// HloInstructionProto: result shape, opcode and operands.
+// Appends the instruction to the entry computation.
+pblczero::HloInstructionProto* HloBuilder::MakeInstruction(
+ std::string_view opcode, const pblczero::XlaShapeProto& shape,
+ const std::vector<HloFlow> operands) {
+ auto instr = std::make_unique<pblczero::HloInstructionProto>();
+ auto ret = instr.get();
+ ret->set_opcode(opcode);
+ *ret->mutable_shape() = shape;
+ *ret->mutable_metadata() = metadata_;
+ ret->set_id(entry_computation_.size());
+ for (const auto& operand : operands) {
+ ret->add_operand_ids(operand->id());
+ }
+ entry_computation_.push_back(std::move(instr));
+ return ret;
+}
+
+// Creates an elementwise instruction, which always has two operands of the
+// same shape.
+pblczero::HloInstructionProto* HloBuilder::MakeElementwiseInstruction(
+ std::string_view opcode, HloFlow lhs, HloFlow rhs) {
+ if (lhs->shape().dimensions() != rhs->shape().dimensions()) {
+ throw Exception("Elementwise operands must have the same shape");
+ }
+ return MakeInstruction(opcode, lhs->shape(), {lhs, rhs});
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Instructions.
+////////////////////////////////////////////////////////////////////////////
+
+HloFlow HloBuilder::Parameter(const pblczero::XlaShapeProto& shape) {
+ return MakeInstruction("parameter", shape, {});
+}
+
+// Converts the element type while keeping the shape.
+HloFlow HloBuilder::Convert(HloFlow input,
+ const pblczero::XlaShapeProto::Type type) {
+ if (input->shape().element_type() == type) return input;
+ pblczero::XlaShapeProto shape = input->shape();
+ shape.set_element_type(type);
+ return MakeInstruction("convert", shape, {input});
+}
+
+HloFlow HloBuilder::Constant(const pblczero::XlaLiteralProto& literal) {
+ auto* flow = MakeInstruction("constant", literal.shape(), {});
+ *flow->mutable_literal() = literal;
+ return flow;
+}
+
+HloFlow HloBuilder::Convolution(
+ HloFlow input, HloFlow filter, const pblczero::XlaWindow& window,
+ const pblczero::XlaConvolutionDimensionNumbers& dn) {
+ if (input->shape().dimensions_size() != filter->shape().dimensions_size()) {
+ throw Exception(
+ "Convolution input and filter shapes must have the "
+ "same number of dimensions");
+ }
+ pblczero::XlaShapeProto shape = input->shape();
+ auto* out_dims = shape.mutable_dimensions();
+ const auto& in_dims = input->shape().dimensions();
+ const auto& filter_dims = filter->shape().dimensions();
+ (*out_dims)[dn.output_batch_dimension()] =
+ in_dims[dn.input_batch_dimension()];
+ (*out_dims)[dn.output_feature_dimension()] =
+ filter_dims[dn.kernel_output_feature_dimension()];
+ for (size_t i = 0; i < dn.input_spatial_dimensions_size(); ++i) {
+ (*out_dims)[dn.output_spatial_dimensions(i)] =
+ in_dims[dn.input_spatial_dimensions(i)];
+ }
+ auto* flow = MakeInstruction("convolution", shape, {input, filter});
+ *flow->mutable_window() = window;
+ *flow->mutable_convolution_dimension_numbers() = dn;
+ return flow;
+}
+
+HloFlow HloBuilder::Broadcast(
+ HloFlow input, const pblczero::XlaShapeProto& target_shape,
+ const std::vector<int64_t>& broadcast_dimensions) {
+ auto flow = MakeInstruction("broadcast", target_shape, {input});
+ if (broadcast_dimensions.size() != input->shape().dimensions_size()) {
+ throw Exception("Broadcast must have the same size as the input shape");
+ }
+ const auto& input_shape = input->shape();
+ for (size_t i = 0; i < broadcast_dimensions.size(); ++i) {
+ auto dim = broadcast_dimensions[i];
+ const auto& input_dim = input_shape.dimensions(i);
+ if (input_dim != 1 && input_dim != target_shape.dimensions(dim)) {
+ throw Exception(
+ "Broadcast dimension must be 1 or equal to the target shape "
+ "dimension");
+ }
+ flow->add_dimensions(dim);
+ }
+ return flow;
+}
+
+HloFlow HloBuilder::Add(HloFlow lhs, HloFlow rhs) {
+ return MakeElementwiseInstruction("add", lhs, rhs);
+}
+
+HloFlow HloBuilder::Maximum(HloFlow lhs, HloFlow rhs) {
+ return MakeElementwiseInstruction("maximum", lhs, rhs);
+}
+
+HloFlow HloBuilder::Reshape(HloFlow input,
+ const pblczero::XlaShapeProto& new_shape) {
+ if (input->shape().element_type() != new_shape.element_type()) {
+ throw Exception("Reshape must have the same element type");
+ }
+ size_t old_elements = std::accumulate(input->shape().dimensions().begin(),
+ input->shape().dimensions().end(), 1,
+ std::multiplies<size_t>());
+ size_t new_elements = std::accumulate(new_shape.dimensions().begin(),
+ new_shape.dimensions().end(), 1,
+ std::multiplies<size_t>());
+ if (old_elements != new_elements) {
+ throw Exception("Reshape must have the same number of elements: " +
+ std::to_string(old_elements) + " vs " +
+ std::to_string(new_elements));
+ }
+ return MakeInstruction("reshape", new_shape, {input});
+}
+
+HloFlow HloBuilder::Dot(HloFlow lhs, HloFlow rhs,
+ const pblczero::XlaDotDimensionNumbers& dn) {
+ pblczero::XlaShapeProto new_shape;
+ if (lhs->shape().element_type() != rhs->shape().element_type()) {
+ throw Exception("Dot operands must have the same element type");
+ }
+ new_shape.set_element_type(lhs->shape().element_type());
+ if (dn.lhs_batch_dimensions_size() != dn.rhs_batch_dimensions_size()) {
+ throw Exception("Dot batch dimensions must have the same size");
+ }
+ for (size_t i = 0; i < dn.lhs_batch_dimensions_size(); ++i) {
+ auto lhs_dim = lhs->shape().dimensions(dn.lhs_batch_dimensions(i));
+ auto rhs_dim = rhs->shape().dimensions(dn.rhs_batch_dimensions(i));
+ if (lhs_dim != rhs_dim) {
+ throw Exception("Dot batch dimensions must have the same size");
+ }
+ new_shape.add_dimensions(lhs_dim);
+ }
+ if (dn.lhs_contracting_dimensions_size() !=
+ dn.rhs_contracting_dimensions_size()) {
+ throw Exception("Dot contracting dimensions must have the same size");
+ }
+ for (size_t i = 0; i < dn.lhs_contracting_dimensions_size(); ++i) {
+ auto lhs_dim = lhs->shape().dimensions(dn.lhs_contracting_dimensions(i));
+ auto rhs_dim = rhs->shape().dimensions(dn.rhs_contracting_dimensions(i));
+ if (lhs_dim != rhs_dim) {
+ throw Exception("Dot contracting dimensions must have the same size");
+ }
+ }
+ // Sorry, github copilot generated the code below (well, above too). Enjoy!
+ for (size_t i = 0; i < lhs->shape().dimensions_size(); ++i) {
+ if (std::find(dn.lhs_batch_dimensions().begin(),
+ dn.lhs_batch_dimensions().end(),
+ i) == dn.lhs_batch_dimensions().end() &&
+ std::find(dn.lhs_contracting_dimensions().begin(),
+ dn.lhs_contracting_dimensions().end(),
+ i) == dn.lhs_contracting_dimensions().end()) {
+ new_shape.add_dimensions(lhs->shape().dimensions(i));
+ }
+ }
+ for (size_t i = 0; i < rhs->shape().dimensions_size(); ++i) {
+ if (std::find(dn.rhs_batch_dimensions().begin(),
+ dn.rhs_batch_dimensions().end(),
+ i) == dn.rhs_batch_dimensions().end() &&
+ std::find(dn.rhs_contracting_dimensions().begin(),
+ dn.rhs_contracting_dimensions().end(),
+ i) == dn.rhs_contracting_dimensions().end()) {
+ new_shape.add_dimensions(rhs->shape().dimensions(i));
+ }
+ }
+ ResetXlaShapeProtoLayout(&new_shape);
+ auto flow = MakeInstruction("dot", new_shape, {lhs, rhs});
+ *flow->mutable_dot_dimension_numbers() = dn;
+ return flow;
+}
+
+HloFlow HloBuilder::Tanh(HloFlow input) {
+ return MakeInstruction("tanh", input->shape(), {input});
+}
+
+HloFlow HloBuilder::Tuple(const std::vector<HloFlow>& elements) {
+ pblczero::XlaShapeProto shape;
+ shape.set_element_type(pblczero::XlaShapeProto::TUPLE);
+ for (const auto& element : elements) {
+ *shape.add_tuple_shapes() = element->shape();
+ }
+ return MakeInstruction("tuple", shape, elements);
+}
+
+namespace {
+// Go over all "parameter" instructions of the computation and assign
+// "parameter_number" field with increasing numbers.
+// Normally it's not required, but in our case it's simpler.
+// Outputs shapes and instruction names of parameters.
+std::pair<std::vector<pblczero::XlaShapeProto>, std::vector<std::string>>
+AssignParameterIndices(const HloComputation& comp) {
+ std::vector<pblczero::XlaShapeProto> parameter_shapes;
+ std::vector<std::string> parameter_names;
+ size_t idx = 0;
+ for (const auto& instr : comp) {
+ if (instr->opcode() == "parameter") {
+ instr->set_parameter_number(idx++);
+ parameter_shapes.push_back(instr->shape());
+ parameter_names.push_back(std::string(instr->name()));
+ }
+ }
+ return {parameter_shapes, parameter_names};
+}
+
+// Finalizes HloComputationProto (sets name, renumbers parameters, adds
+// computation shape and root instruction).
+pblczero::HloComputationProto MakeComputation(const HloComputation& comp,
+ std::string_view name,
+ size_t id) {
+ pblczero::HloComputationProto ret;
+ ret.set_id(id);
+ ret.set_name(name);
+ auto [shapes, names] = AssignParameterIndices(comp);
+ for (auto& instr : comp) *ret.add_instructions() = *instr;
+ *ret.mutable_program_shape()->mutable_parameters() = shapes;
+ *ret.mutable_program_shape()->mutable_parameter_names() = names;
+ *ret.mutable_program_shape()->mutable_result() = comp.back()->shape();
+ ret.set_root_id(comp.back()->id());
+ return ret;
+}
+} // namespace
+
+// Assigns unique names to all instructions in the module.
+// In StableHLO instructions are allowed to have numeric names, but in XLA HLO
+// they are not, so we use "i"+number.
+void HloBuilder::AssignInstructionNames() {
+ // Every instruction in the module should have a unique name; numeric names
+ // are allowed.
+ size_t idx = 0;
+ for (auto& instr : entry_computation_) {
+ instr->set_name("i" + std::to_string(idx++));
+ }
+ for (auto& [_, comp] : dependent_computations_) {
+ for (auto& instr : *comp.mutable_instructions()) {
+ instr.set_name("i" + std::to_string(idx++));
+ }
+ }
+}
+
+pblczero::HloModuleProto HloBuilder::Build(std::string_view name) {
+ AssignInstructionNames();
+ pblczero::HloModuleProto module;
+ module.set_name(name);
+ module.set_entry_computation_name("main");
+ module.set_entry_computation_id(0);
+ *module.add_computations() = MakeComputation(entry_computation_, "main", 0);
+ for (auto& [name, comp] : dependent_computations_) {
+ *module.add_computations() = comp;
+ }
+ *module.mutable_host_program_shape() = module.computations(0).program_shape();
+ return module;
+}
+
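+// Resets the layout to the default row-major order (minor_to_major =
+// {rank-1, ..., 1, 0}) and marks every dimension as static.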
+void ResetXlaShapeProtoLayout(pblczero::XlaShapeProto* shape) {
+ shape->mutable_layout()->mutable_minor_to_major()->clear();
+ shape->mutable_is_dynamic_dimension()->clear();
+
+ for (size_t i = 0; i < shape->dimensions_size(); ++i) {
+ shape->add_is_dynamic_dimension(false);
+ shape->mutable_layout()->add_minor_to_major(shape->dimensions_size() - i -
+ 1);
+ }
+}
+
+} // namespace lczero
\ No newline at end of file
diff --git a/src/neural/xla/hlo_builder.h b/src/neural/xla/hlo_builder.h
new file mode 100644
index 0000000000..b294408843
--- /dev/null
+++ b/src/neural/xla/hlo_builder.h
@@ -0,0 +1,107 @@
+/*
+ This file is part of Leela Chess Zero.
+ Copyright (C) 2024 The LCZero Authors
+
+ Leela Chess is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Chess is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Chess. If not, see <https://www.gnu.org/licenses/>.
+
+ Additional permission under GNU GPL version 3 section 7
+
+ If you modify this Program, or any covered work, by linking or
+ combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
+ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
+ modified version of those libraries), containing parts covered by the
+ terms of the respective license agreement, the licensors of this
+ Program grant you additional permission to convey the resulting work.
+*/
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <vector>
+
+#include "neural/xla/hlo.pb.h"
+#include "utils/logging.h"
+
+namespace lczero {
+
+class HloContext;
+class HloBuilder;
+
+using HloFlow = const pblczero::HloInstructionProto*;
+using HloComputation =
+ std::vector<std::unique_ptr<pblczero::HloInstructionProto>>;
+
+// A builder class for constructing HloModuleProto.
+class HloBuilder {
+ public:
+ // HLO operations.
+ HloFlow Parameter(const pblczero::XlaShapeProto& shape);
+ HloFlow Constant(const pblczero::XlaLiteralProto& literal);
+ HloFlow Convert(HloFlow input, const pblczero::XlaShapeProto::Type type);
+ HloFlow Convolution(
+ HloFlow input, HloFlow filter, const pblczero::XlaWindow& window,
+ const pblczero::XlaConvolutionDimensionNumbers& dimension_numbers);
+ HloFlow Broadcast(HloFlow input, const pblczero::XlaShapeProto& target_shape,
+ const std::vector<int64_t>& broadcast_dimensions);
+ HloFlow Add(HloFlow lhs, HloFlow rhs);
+ HloFlow Maximum(HloFlow lhs, HloFlow rhs);
+ HloFlow Reshape(HloFlow input, const pblczero::XlaShapeProto& new_shape);
+ HloFlow Dot(HloFlow lhs, HloFlow rhs,
+ const pblczero::XlaDotDimensionNumbers& dimension_numbers);
+ HloFlow Tanh(HloFlow input);
+ HloFlow Tuple(const std::vector<HloFlow>& elements);
+
+ // Build the HloModuleProto with a given name.
+ pblczero::HloModuleProto Build(std::string_view name);
+
+ private:
+ pblczero::HloInstructionProto* MakeInstruction(
+ std::string_view opcode, const pblczero::XlaShapeProto& shape,
+ const std::vector<HloFlow> operands);
+ pblczero::HloInstructionProto* MakeElementwiseInstruction(
+ std::string_view opcode, HloFlow lhs, HloFlow rhs);
+ void AssignInstructionNames();
+
+ HloComputation entry_computation_;
+ std::unordered_map<std::string, pblczero::HloComputationProto>
+ dependent_computations_;
+ pblczero::XlaOpMetadata metadata_;
+ friend class HloContext;
+};
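+
+// A minimal usage sketch (for illustration only; the shape and names are
+// hypothetical):
+//   HloBuilder builder;
+//   HloFlow x = builder.Parameter(input_shape);
+//   HloFlow y = builder.Tanh(x);
+//   builder.Tuple({y});
+//   pblczero::HloModuleProto module = builder.Build("model");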
+
+// A context class for annotating parts of the HLO computation with metadata,
+// like original ONNX op, its name, and source file name and line.
+// The class saves the current metadata in constructor and restores it in
+// destructor, making it possible to use it in a scoped way.
+class HloContext {
+ public:
+ HloContext(HloBuilder* builder)
+ : builder_(builder), saved_metadata_(builder->metadata_) {}
+ ~HloContext() { builder_->metadata_ = saved_metadata_; }
+ void SetOpType(std::string_view op_type) const {
+ builder_->metadata_.set_op_type(op_type);
+ }
+ void SetOpName(std::string_view op_name) const {
+ builder_->metadata_.set_op_name(op_name);
+ }
+
+ private:
+ HloBuilder* builder_;
+ pblczero::XlaOpMetadata saved_metadata_;
+};
+
+// A helper function to reset the layout of a shape. Marks all dimensions as
+// non-dynamic and sets the layout to the default row-major order.
+void ResetXlaShapeProtoLayout(pblczero::XlaShapeProto* shape);
+
+} // namespace lczero
\ No newline at end of file
diff --git a/src/neural/xla/network_xla.cc b/src/neural/xla/network_xla.cc
new file mode 100644
index 0000000000..37eec13ee9
--- /dev/null
+++ b/src/neural/xla/network_xla.cc
@@ -0,0 +1,305 @@
+/*
+ This file is part of Leela Chess Zero.
+ Copyright (C) 2024 The LCZero Authors
+
+ Leela Chess is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Chess is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Chess. If not, see <https://www.gnu.org/licenses/>.
+
+ Additional permission under GNU GPL version 3 section 7
+
+ If you modify this Program, or any covered work, by linking or
+ combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
+ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
+ modified version of those libraries), containing parts covered by the
+ terms of the respective license agreement, the licensors of this
+ Program grant you additional permission to convey the resulting work.
+*/
+
+#include <cassert>
+#include <cstring>
+#include <memory>
+#include <optional>
+
+#include "neural/factory.h"
+#include "neural/network.h"
+#include "neural/onnx/converter.h"
+#include "neural/xla/onnx2hlo.h"
+#include "neural/xla/xla_runner.h"
+#include "utils/bititer.h"
+
+namespace lczero {
+namespace {
+
+class Lc0InputTensor : public XlaTensor {
+ public:
+ Lc0InputTensor(size_t max_batch_size)
+ : max_batch_size_(max_batch_size),
+ // TODO replace with make_unique_for_overwrite() once C++20 is
+ // available.
+ data_(new float[GetTensorByteSizeForBatch(max_batch_size)]),
+ shape_{0, kInputPlanes, 8, 8} {}
+
+ const std::vector<int64_t>& shape() const override { return shape_; }
+ const void* data() const override { return data_.get(); }
+ size_t size() const override { return GetTensorByteSizeForBatch(shape_[0]); }
+ size_t capacity() const override {
+ return GetTensorByteSizeForBatch(max_batch_size_);
+ }
+ pblczero::XlaShapeProto::Type type() const override {
+ return pblczero::XlaShapeProto::F32;
+ }
+
+ // Adds a batch to the tensor and returns a pointer to the start of its
+ // part of the buffer. Does NOT initialize the data with zeros.
+ float* AddBatch() {
+ assert(size_t(shape_[0]) < max_batch_size_);
+ auto ret = data_.get() + shape_[0] * GetTensorByteSizeForBatch(1);
+ ++shape_[0];
+ return ret;
+ }
+ size_t GetBatchSize() const { return shape_[0]; }
+
+ private:
+ static size_t GetTensorByteSizeForBatch(size_t batch_size) {
+ return kInputPlanes * 8 * 8 * batch_size * sizeof(float);
+ }
+
+ const size_t max_batch_size_;
+ std::unique_ptr<float[]> data_;
+ std::vector<int64_t> shape_;
+};
+
+class XlaNetwork;
+class XlaComputation : public NetworkComputation {
+ public:
+ XlaComputation(const XlaNetwork* network);
+ void AddInput(InputPlanes&& input) override;
+ int GetBatchSize() const override;
+ void ComputeBlocking() override;
+ float GetQVal(int sample) const override;
+ float GetDVal(int sample) const override;
+ float GetPVal(int sample, int move_id) const override;
+ float GetMVal(int sample) const override;
+
+ private:
+ const XlaNetwork* network_;
+ Lc0InputTensor input_tensor_;
+ std::vector<std::unique_ptr<XlaTensor>> outputs_;
+};
+
+// Indices of various heads in the HLO output.
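+// Each index is the position of the corresponding head in the output tuple
+// that the HLO module returns (see BuildOutputs in onnx2hlo.cc).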
+struct XlaNetworkOptions {
+ std::optional<size_t> output_value_idx;
+ std::optional<size_t> output_wdl_idx;
+ std::optional<size_t> output_policy_idx;
+ std::optional<size_t> output_mlh_idx;
+};
+
+class XlaNetwork : public Network {
+ public:
+ XlaNetwork(std::unique_ptr<XlaRunner> runner,
+ const XlaNetworkOptions& options,
+ const pblczero::NetworkFormat& format);
+
+ const NetworkCapabilities& GetCapabilities() const override {
+ return capabilities_;
+ }
+ std::unique_ptr<NetworkComputation> NewComputation() override {
+ return std::make_unique(this);
+ }
+ int GetMiniBatchSize() const override {
+ // 32 is the default prefetch size, subtract it so that backend doesn't
+ // crash.
+ // TODO make it better when we have a proper way to query the batch size.
+ return runner_->GetMaxBatchSize() - 32;
+ }
+
+ private:
+ std::unique_ptr<XlaRunner> runner_;
+ XlaNetworkOptions options_;
+ NetworkCapabilities capabilities_;
+
+ friend class XlaComputation;
+};
+
+XlaComputation::XlaComputation(const XlaNetwork* network)
+ : network_(network), input_tensor_(network->runner_->GetMaxBatchSize()) {}
+
+void XlaComputation::AddInput(InputPlanes&& input) {
+ float* ptr = input_tensor_.AddBatch();
+ memset(ptr, 0, 8 * 8 * kInputPlanes * sizeof(float));
+ for (const auto& plane : input) {
+ for (auto bit : IterateBits(plane.mask)) ptr[bit] = plane.value;
+ ptr += 8 * 8;
+ }
+}
+
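+// Q is derived from the WDL head as P(win) - P(loss); when only a value head
+// is present, its single output is returned directly.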
+float XlaComputation::GetQVal(int sample) const {
+ if (network_->options_.output_wdl_idx) {
+ const float* data = reinterpret_cast(
+ outputs_[*network_->options_.output_wdl_idx]->data());
+ return data[sample * 3 + 0] - data[sample * 3 + 2];
+ } else {
+ const float* data = reinterpret_cast(
+ outputs_[*network_->options_.output_value_idx]->data());
+ return data[sample];
+ }
+}
+
+float XlaComputation::GetDVal(int sample) const {
+ if (network_->options_.output_wdl_idx) {
+ const float* data = reinterpret_cast(
+ outputs_[*network_->options_.output_wdl_idx]->data());
+ return data[sample * 3 + 1];
+ }
+ return 0.0f;
+}
+
+float XlaComputation::GetPVal(int sample, int move_id) const {
+ const float* data = reinterpret_cast(
+ outputs_[*network_->options_.output_policy_idx]->data());
+ return data[sample * 1858 + move_id];
+}
+
+float XlaComputation::GetMVal(int sample) const {
+ if (network_->options_.output_mlh_idx) {
+ const float* data = reinterpret_cast(
+ outputs_[*network_->options_.output_mlh_idx]->data());
+ return data[sample];
+ }
+ return 0.0f;
+}
+
+int XlaComputation::GetBatchSize() const {
+ return input_tensor_.GetBatchSize();
+}
+
+void XlaComputation::ComputeBlocking() {
+ outputs_ = network_->runner_->ExecuteBlocking({&input_tensor_});
+}
+
+XlaNetwork::XlaNetwork(std::unique_ptr<XlaRunner> runner,
+ const XlaNetworkOptions& options,
+ const pblczero::NetworkFormat& format)
+ : runner_(std::move(runner)),
+ options_(options),
+ capabilities_{format.input(), format.moves_left()} {}
+
+// Converts ONNX model to HLO (for various batch sizes) and adds them to the
+// XlaRunner.
+XlaNetworkOptions FillXlaRunnerFromOnnx(const pblczero::OnnxModel& onnx_model,
+ XlaRunner* runner,
+ size_t max_batch_size, size_t steps) {
+ pblczero::ModelProto onnx;
+ onnx.ParseFromString(onnx_model.model());
+
+ std::unordered_map<std::string, size_t> constant_to_parameter_idx;
+ std::unordered_map<std::string, size_t> input_to_parameter_idx;
+ std::unordered_map<std::string, size_t> output_to_parameter_idx;
+
+ auto add_tensors = [](const std::vector<Onnx2HloResult::NamedTensor>& tensors,
+ std::unordered_map<std::string, size_t>& map) {
+ for (const auto& tensor : tensors) {
+ auto iter = map.find(tensor.name);
+ if (iter == map.end()) {
+ map[tensor.name] = tensor.param_idx;
+ } else if (iter->second != tensor.param_idx) {
+ throw Exception("Inconsistent index for " + tensor.name);
+ }
+ }
+ };
+
+ for (size_t i = 0; i < steps; ++i) {
+ size_t batch_size = max_batch_size * (i + 1) / steps;
+ CERR << "Building HLO for batch size " << batch_size << "...";
+ auto conversion = ConvertOnnxToHlo(onnx, batch_size, {});
+ add_tensors(conversion.constants, constant_to_parameter_idx);
+ add_tensors(conversion.inputs, input_to_parameter_idx);
+ add_tensors(conversion.outputs, output_to_parameter_idx);
+ runner->AddModule(batch_size, conversion.hlo_module);
+ }
+
+ std::vector<std::unique_ptr<XlaTensor>> constants;
+ constants.resize(constant_to_parameter_idx.size() +
+ input_to_parameter_idx.size());
+ for (const auto& initializer : onnx.graph().initializer()) {
+ auto iter = constant_to_parameter_idx.find(std::string(initializer.name()));
+ if (iter == constant_to_parameter_idx.end()) continue;
+ auto idx = iter->second;
+ assert(idx < constants.size());
+ constants[idx] = OnnxTensorToXlaTensor(initializer);
+ }
+
+ CERR << "Transferring constants...";
+ runner->SetFrozenInputs(std::move(constants));
+ CERR << "Done.";
+
+ XlaNetworkOptions options;
+ if (input_to_parameter_idx.size() != 1 ||
+ input_to_parameter_idx.begin()->first != onnx_model.input_planes()) {
+ throw Exception("Expected a single input named " +
+ std::string(onnx_model.input_planes()));
+ }
+ if (onnx_model.has_output_value()) {
+ options.output_value_idx =
+ output_to_parameter_idx.at(std::string(onnx_model.output_value()));
+ }
+ if (onnx_model.has_output_wdl()) {
+ options.output_wdl_idx =
+ output_to_parameter_idx.at(std::string(onnx_model.output_wdl()));
+ }
+ if (onnx_model.has_output_policy()) {
+ options.output_policy_idx =
+ output_to_parameter_idx.at(std::string(onnx_model.output_policy()));
+ }
+ if (onnx_model.has_output_mlh()) {
+ options.output_mlh_idx =
+ output_to_parameter_idx.at(std::string(onnx_model.output_mlh()));
+ }
+ return options;
+}
+
+// Makes an XLA network. First converts the weights to ONNX, and then calls
+// FillXlaRunnerFromOnnx to convert them further to HLO and compile them.
+std::unique_ptr<Network> MakeXlaNetwork(const std::optional<WeightsFile>& w,
+ const OptionsDict& opts) {
+ if (!w) throw Exception("The XLA backend requires a network file.");
+ int device = opts.GetOrDefault<int>("device", 0);
+ // Note: if the plugin_path does NOT contain a slash, it's looked up in the
+ // LD_LIBRARY_PATH (and a few other system defined places). If it does contain
+ // a slash, it's looked up at the exact relative or absolute path.
+ auto runner = std::make_unique<XlaRunner>(
+ opts.GetOrDefault<std::string>("plugin_path",
+ "./pjrt_c_api_gpu_plugin.so")
+ .c_str(),
+ device);
+ int max_batch_size = opts.GetOrDefault<int>("max_batch", 739);
+ int steps = opts.GetOrDefault<int>("steps", 13);
+
+ XlaNetworkOptions options;
+ if (w->has_onnx_model()) {
+ options = FillXlaRunnerFromOnnx(w->onnx_model(), runner.get(),
+ max_batch_size, steps);
+ } else {
+ CERR << "Converting weights to ONNX first.";
+ WeightsToOnnxConverterOptions onnx_converter_options;
+ auto converted = ConvertWeightsToOnnx(*w, onnx_converter_options);
+ options = FillXlaRunnerFromOnnx(converted.onnx_model(), runner.get(),
+ max_batch_size, steps);
+ }
+
+ return std::make_unique(std::move(runner), options,
+ w->format().network_format());
+}
+
+REGISTER_NETWORK("xla", MakeXlaNetwork, -34)
+
+} // namespace
+} // namespace lczero
diff --git a/src/neural/xla/onnx2hlo.cc b/src/neural/xla/onnx2hlo.cc
new file mode 100644
index 0000000000..b41906e71d
--- /dev/null
+++ b/src/neural/xla/onnx2hlo.cc
@@ -0,0 +1,532 @@
+/*
+ This file is part of Leela Chess Zero.
+ Copyright (C) 2024 The LCZero Authors
+
+ Leela Chess is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Chess is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Chess. If not, see <https://www.gnu.org/licenses/>.
+
+ Additional permission under GNU GPL version 3 section 7
+
+ If you modify this Program, or any covered work, by linking or
+ combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
+ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
+ modified version of those libraries), containing parts covered by the
+ terms of the respective license agreement, the licensors of this
+ Program grant you additional permission to convey the resulting work.
+*/
+
+#include "neural/xla/onnx2hlo.h"
+
+#include <algorithm>
+
+#include "neural/onnx/onnx.pb.h"
+#include "neural/xla/hlo.pb.h"
+#include "neural/xla/hlo_builder.h"
+#include "neural/xla/print_hlo.h"
+#include "utils/exception.h"
+
+namespace lczero {
+namespace {
+
+pblczero::XlaShapeProto::Type OnnxTypeToXlaType(
+ const pblczero::TensorProto::DataType& type) {
+ switch (type) {
+ case pblczero::TensorProto::FLOAT:
+ return pblczero::XlaShapeProto::F32;
+ case pblczero::TensorProto::UINT8:
+ return pblczero::XlaShapeProto::U8;
+ case pblczero::TensorProto::INT8:
+ return pblczero::XlaShapeProto::S8;
+ case pblczero::TensorProto::UINT16:
+ return pblczero::XlaShapeProto::U16;
+ case pblczero::TensorProto::INT16:
+ return pblczero::XlaShapeProto::S16;
+ case pblczero::TensorProto::INT32:
+ return pblczero::XlaShapeProto::S32;
+ case pblczero::TensorProto::INT64:
+ return pblczero::XlaShapeProto::S64;
+ case pblczero::TensorProto::BOOL:
+ return pblczero::XlaShapeProto::PRED;
+ case pblczero::TensorProto::FLOAT16:
+ return pblczero::XlaShapeProto::F16;
+ case pblczero::TensorProto::DOUBLE:
+ return pblczero::XlaShapeProto::F64;
+ case pblczero::TensorProto::UINT32:
+ return pblczero::XlaShapeProto::U32;
+ case pblczero::TensorProto::UINT64:
+ return pblczero::XlaShapeProto::U64;
+ case pblczero::TensorProto::COMPLEX64:
+ return pblczero::XlaShapeProto::C64;
+ case pblczero::TensorProto::COMPLEX128:
+ return pblczero::XlaShapeProto::C128;
+ case pblczero::TensorProto::BFLOAT16:
+ return pblczero::XlaShapeProto::BF16;
+ case pblczero::TensorProto::STRING:
+ default:
+ throw Exception("Unsupported ONNX type " +
+ pblczero::TensorProto::DataType_Name(type));
+ }
+}
+
+// Converts an ONNX shape to an XLA shape, replacing the batch dimension with
+// the provided batch size.
+pblczero::XlaShapeProto OnnxShapeToXlaShape(const pblczero::TypeProto& type,
+ std::optional batch_size) {
+ pblczero::XlaShapeProto shape;
+ shape.set_element_type(OnnxTypeToXlaType(type.tensor_type().elem_type()));
+ for (const auto& dim : type.tensor_type().shape().dim()) {
+ if (dim.has_dim_value()) {
+ shape.add_dimensions(dim.dim_value());
+ continue;
+ }
+ if (dim.dim_param() == "batch") {
+ if (batch_size.has_value()) {
+ shape.add_dimensions(batch_size.value());
+ continue;
+ }
+ throw Exception("Batch size not provided");
+ }
+ throw Exception("Unsupported dimension type " + type.OutputAsJson());
+ }
+ ResetXlaShapeProtoLayout(&shape);
+ return shape;
+}
+
+// Type is not a field of the ONNX tensor, so this function extracts the shape
+// and converts it (discarding the data).
+pblczero::XlaShapeProto OnnxTensorToXlaShape(
+ const pblczero::TensorProto& tensor) {
+ pblczero::TypeProto type;
+ type.mutable_tensor_type()->set_elem_type(tensor.data_type());
+ for (const auto& dim : tensor.dims()) {
+ type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
+ }
+ return OnnxShapeToXlaShape(type, std::nullopt);
+}
+
+// Converts an ONNX tensor to an XLA literal (which is a shape and a data).
+pblczero::XlaLiteralProto OnnxTensorToXlaLiteral(
+ const pblczero::TensorProto& tensor) {
+ pblczero::XlaLiteralProto literal;
+ *literal.mutable_shape() = OnnxTensorToXlaShape(tensor);
+
+ auto convert = [&](std::string_view src, /*std::vector*/ auto* dst) {
+ using value_type =
+ typename std::remove_pointer<decltype(dst)>::type::value_type;
+ dst->assign(reinterpret_cast(src.data()),
+ reinterpret_cast(src.data() + src.size()));
+ };
+
+ switch (tensor.data_type()) {
+ case pblczero::TensorProto::FLOAT:
+ convert(tensor.raw_data(), literal.mutable_f32s());
+ break;
+ case pblczero::TensorProto::INT64:
+ convert(tensor.raw_data(), literal.mutable_s64s());
+ break;
+ default:
+ throw Exception("Cannot convert ONNX tensor to XLA literal for type " +
+ pblczero::XlaShapeProto::Type_Name(
+ OnnxTypeToXlaType(tensor.data_type())));
+ }
+ return literal;
+}
+
+class Onnx2HloConverter {
+ public:
+ Onnx2HloConverter(const Onnx2HloOptions& options) : options_(options) {
+ onnx_op_to_builder_["Add"] = &Onnx2HloConverter::OpAdd;
+ onnx_op_to_builder_["Conv"] = &Onnx2HloConverter::OpConv;
+ onnx_op_to_builder_["MatMul"] = &Onnx2HloConverter::OpMatMul;
+ onnx_op_to_builder_["Relu"] = &Onnx2HloConverter::OpRelu;
+ onnx_op_to_builder_["Reshape"] = &Onnx2HloConverter::OpReshape;
+ onnx_op_to_builder_["Tanh"] = &Onnx2HloConverter::OpTanh;
+ }
+
+ Onnx2HloResult Convert(const pblczero::ModelProto& onnx_model,
+ size_t minibatch_size) {
+ batch_size_ = minibatch_size;
+ // Populate the set of ONNX initializers (constants), but do not emit them
+ // yet. They are emitted lazily so that they appear close to their first use.
+ BuildInitializerMapping(onnx_model);
+ // Convert ONNX inputs to HLO parameters.
+ BuildInputs(onnx_model.graph().input());
+ BuildGraph(onnx_model.graph());
+ Onnx2HloResult result;
+ // Convert ONNX outputs to HLO result.
+ result.outputs = BuildOutputs(onnx_model.graph().output());
+ result.hlo_module = builder_.Build("onnx_model");
+ for (size_t i = 0; i < params_.size(); ++i) {
+ const auto& param = params_[i];
+ auto& dst = param.is_constant ? result.constants : result.inputs;
+ dst.push_back({i, param.name, param.flow->shape()});
+ }
+ // PrettyPrintHlo(result.hlo_module, {}, std::cout);
+ return result;
+ }
+
+ private:
+ std::vector<Onnx2HloResult::NamedTensor> BuildOutputs(
+ const std::vector<pblczero::ValueInfoProto>& graph_output) {
+ // Gathers outputs into the root tuple, optionally converting their type if
+ // I/O type is different from the instruction output.
+ std::vector<Onnx2HloResult::NamedTensor> result;
+ std::vector<HloFlow> outputs;
+ for (size_t i = 0; i < graph_output.size(); ++i) {
+ const auto& output = graph_output[i];
+ auto flow = GetFlowByName(std::string(output.name()));
+ if (flow->shape().element_type() != options_.io_type) {
+ auto ctx = HloContext(&builder_);
+ ctx.SetOpType("output");
+ ctx.SetOpName(output.name());
+ flow = builder_.Convert(flow, options_.io_type);
+ }
+ result.push_back({i, std::string(output.name()), flow->shape()});
+ outputs.push_back(flow);
+ }
+ builder_.Tuple(outputs);
+ return result;
+ }
+
+ void BuildInitializerMapping(const pblczero::ModelProto& onnx_model) {
+ for (const auto& tensor : onnx_model.graph().initializer()) {
+ initializers_[std::string(tensor.name())] = &tensor;
+ }
+ }
+
+ // Checks that the ONNX node doesn't have any unknown attributes.
+ void CheckKnownAttributes(
+ const pblczero::NodeProto& node,
+ const std::initializer_list<std::string_view> attributes) {
+ for (const auto& attribute : node.attribute()) {
+ if (std::find(attributes.begin(), attributes.end(), attribute.name()) ==
+ attributes.end()) {
+ throw Exception("Unknown attribute " + std::string(attribute.name()));
+ }
+ }
+ }
+
+ // Fetches an HloFlow by name. If the name is not in the map, check whether
+ // there is an initializer for it, and either create a constant or a parameter
+ // depending on its size.
+ HloFlow GetFlowByName(const std::string& name) {
+ auto iter = onnx_name_to_hlo_flow_.find(name);
+ if (iter != onnx_name_to_hlo_flow_.end()) return iter->second;
+
+ auto iter2 = initializers_.find(name);
+ if (iter2 == initializers_.end()) {
+ throw Exception("Unknown input " + name);
+ }
+ auto ctx = HloContext(&builder_);
+ ctx.SetOpType("initializer");
+ ctx.SetOpName(name);
+
+ HloFlow flow = nullptr;
+ if (iter2->second->raw_data().size() <= options_.max_inline_constant_size) {
+ flow = builder_.Constant(OnnxTensorToXlaLiteral(*iter2->second));
+ } else {
+ const auto shape = OnnxTensorToXlaShape(*iter2->second);
+ flow = MakeParameter(name, shape, true);
+ }
+ onnx_name_to_hlo_flow_[name] = flow;
+ return flow;
+ }
+
+ // A helper function to fetch an input of ONNX node by index.
+ HloFlow GetInput(const pblczero::NodeProto& node, size_t idx,
+ bool optional = false) {
+ if (idx >= node.input_size()) {
+ if (optional) return nullptr;
+ throw Exception("Input " + std::to_string(idx) + " not set");
+ }
+ return GetFlowByName(std::string(node.input(idx)));
+ }
+
+ // A helper function to fetch an attribute of ONNX node by name.
+ const pblczero::AttributeProto* GetAttribute(const pblczero::NodeProto& node,
+ std::string_view name,
+ bool optional = false) {
+ for (const auto& attribute : node.attribute()) {
+ if (attribute.name() == name) return &attribute;
+ }
+ if (optional) return nullptr;
+ throw Exception("Attribute " + std::string(name) + " not set");
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // ONNX operations
+ /////////////////////////////////////////////////////////////////////////////
+
+ std::vector<HloFlow> OpConv(const pblczero::NodeProto& node) {
+ CheckKnownAttributes(node, {"pads", "kernel_shape"});
+ auto* input = GetInput(node, 0);
+ auto* kernel = GetInput(node, 1);
+ auto* bias = GetInput(node, 2, true);
+
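+ // ONNX Conv uses NCHW activations and OIHW kernels, so dimension 0 is
+ // batch, dimension 1 is features, and the remaining dimensions are spatial.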
+ pblczero::XlaConvolutionDimensionNumbers dn;
+ dn.set_input_batch_dimension(0);
+ dn.set_input_feature_dimension(1);
+ dn.set_kernel_input_feature_dimension(1);
+ dn.set_kernel_output_feature_dimension(0);
+ dn.set_output_batch_dimension(0);
+ dn.set_output_feature_dimension(1);
+ const size_t num_dims = input->shape().dimensions_size() - 2;
+ for (size_t i = 0; i < num_dims; ++i) {
+ dn.add_input_spatial_dimensions(i + 2);
+ dn.add_kernel_spatial_dimensions(i + 2);
+ dn.add_output_spatial_dimensions(i + 2);
+ }
+
+ const auto* pads = GetAttribute(node, "pads");
+ const auto* kernel_shape = GetAttribute(node, "kernel_shape");
+ if (!pads || pads->ints_size() != 2 * num_dims) {
+ throw Exception("'pads' attribute not set or wrong size");
+ }
+ if (!kernel_shape || kernel_shape->ints_size() != num_dims) {
+ throw Exception("'kernel_shape' attribute not set or wrong size");
+ }
+ pblczero::XlaWindow window;
+ for (size_t i = 0; i < input->shape().dimensions_size() - 2; ++i) {
+ auto* dim = window.add_dimensions();
+ dim->set_size(kernel_shape->ints(i));
+ dim->set_stride(1);
+ dim->set_padding_low(pads->ints(i));
+ dim->set_padding_high(pads->ints(i + num_dims));
+ dim->set_window_dilation(1);
+ dim->set_base_dilation(1);
+ }
+
+ auto* conv = builder_.Convolution(input, kernel, window, dn);
+
+ if (!bias) return {conv};
+ auto* flow = builder_.Broadcast(bias, conv->shape(), {1});
+ return {builder_.Add(conv, flow)};
+ }
+
+ std::vector<HloFlow> OpRelu(const pblczero::NodeProto& node) {
+ CheckKnownAttributes(node, {});
+ auto* input = GetInput(node, 0);
+ auto* zero = MakeScalar(0, input->shape().element_type());
+ zero = builder_.Broadcast(zero, input->shape(), {});
+ return {builder_.Maximum(input, zero)};
+ }
+
+ std::vector<HloFlow> OpTanh(const pblczero::NodeProto& node) {
+ CheckKnownAttributes(node, {});
+ auto* input = GetInput(node, 0);
+ return {builder_.Tanh(input)};
+ }
+
+ std::vector<HloFlow> OpAdd(const pblczero::NodeProto& node) {
+ CheckKnownAttributes(node, {});
+ auto* lhs = GetInput(node, 0);
+ auto* rhs = GetInput(node, 1);
+ std::tie(lhs, rhs) = EqualizeShape(lhs, rhs);
+ return {builder_.Add(lhs, rhs)};
+ }
+
+ std::vector<HloFlow> OpReshape(const pblczero::NodeProto& node) {
+ CheckKnownAttributes(node, {});
+ auto* input = GetInput(node, 0);
+ if (node.input_size() < 2) {
+ throw Exception("Reshape requires a shape input");
+ }
+ auto dims_tensor = initializers_.find(std::string(node.input(1)));
+ if (dims_tensor == initializers_.end()) {
+ throw Exception("Reshape only supports constant shape");
+ }
+ auto new_dims = OnnxTensorToXlaLiteral(*dims_tensor->second).s64s();
+ pblczero::XlaShapeProto new_shape;
+ new_shape.set_element_type(input->shape().element_type());
+ for (size_t i = 0; i < new_dims.size(); ++i) {
+ auto dim = new_dims[i];
+ if (dim == -1) dim = batch_size_;
+ if (dim == 0) {
+ if (new_dims.size() != input->shape().dimensions_size()) {
+ throw Exception("Reshape cannot infer shape when rank changes");
+ }
+ dim = input->shape().dimensions(i);
+ }
+ new_shape.add_dimensions(dim);
+ }
+ ResetXlaShapeProtoLayout(&new_shape);
+ return {builder_.Reshape(input, new_shape)};
+ }
+
+ std::vector<HloFlow> OpMatMul(const pblczero::NodeProto& node) {
+ CheckKnownAttributes(node, {});
+ auto* lhs = GetInput(node, 0);
+ auto* rhs = GetInput(node, 1);
+ if (lhs->shape().dimensions_size() != 2 ||
+ rhs->shape().dimensions_size() != 2) {
+ throw Exception("MatMul only implemented for 2D inputs so far");
+ }
+ pblczero::XlaDotDimensionNumbers dn;
+ dn.add_lhs_contracting_dimensions(1);
+ dn.add_rhs_contracting_dimensions(0);
+ return {builder_.Dot(lhs, rhs, dn)};
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+
+ // Makes a scalar constant (usually 0 or 1) of a given type.
+ template <typename T>
+ HloFlow MakeScalar(T value, pblczero::XlaShapeProto::Type type) {
+ pblczero::XlaLiteralProto literal;
+ literal.mutable_shape()->set_element_type(type);
+ literal.mutable_shape()->mutable_layout();
+ switch (type) {
+ case pblczero::XlaShapeProto::F32:
+ literal.add_f32s(value);
+ break;
+ default:
+ throw Exception("Unsupported type for zero constant");
+ }
+ return builder_.Constant(literal);
+ }
+
+ // Takes two inputs and optionally performs numpy-style broadcasting to make
+ // their shapes equal.
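+ // For example (numpy-style), adding a [64] bias to a [N, 64] tensor
+ // broadcasts the bias along the leading dimension; two shapes are
+ // incompatible if a dimension other than 1 differs between them.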
+ std::pair<HloFlow, HloFlow> EqualizeShape(HloFlow lhs, HloFlow rhs) {
+ const auto& lhs_dims = lhs->shape().dimensions();
+ const auto& rhs_dims = rhs->shape().dimensions();
+
+ const size_t num_dims = std::max(lhs_dims.size(), rhs_dims.size());
+ std::vector<int64_t> output_dims(num_dims);
+ std::vector<int64_t> lhs_broadcast_dims;
+ std::vector<int64_t> rhs_broadcast_dims;
+ bool lhs_broadcast = lhs_dims.size() < num_dims;
+ bool rhs_broadcast = rhs_dims.size() < num_dims;
+
+ for (size_t i = 0; i < num_dims; ++i) {
+ int lhs_idx = i + lhs_dims.size() - num_dims;
+ int rhs_idx = i + rhs_dims.size() - num_dims;
+ const auto lhs_dim = (lhs_idx < 0) ? 1 : lhs_dims[lhs_idx];
+ const auto rhs_dim = (rhs_idx < 0) ? 1 : rhs_dims[rhs_idx];
+ if (lhs_dim != rhs_dim) {
+ if (lhs_dim != 1 && rhs_dim != 1) {
+ throw Exception("Incompatible shapes for broadcast");
+ }
+ if (lhs_dim == 1) lhs_broadcast = true;
+ if (rhs_dim == 1) rhs_broadcast = true;
+ }
+ if (lhs_idx >= 0) lhs_broadcast_dims.push_back(i);
+ if (rhs_idx >= 0) rhs_broadcast_dims.push_back(i);
+ }
+
+ if (lhs_broadcast) {
+ lhs = builder_.Broadcast(lhs, rhs->shape(), lhs_broadcast_dims);
+ }
+ if (rhs_broadcast) {
+ rhs = builder_.Broadcast(rhs, lhs->shape(), rhs_broadcast_dims);
+ }
+ return {lhs, rhs};
+ }
+
+ // Converts ONNX inputs to HLO parameters.
+ void BuildInputs(const std::vector<pblczero::ValueInfoProto>& inputs) {
+ for (const auto& input : inputs) {
+ auto ctx = HloContext(&builder_);
+ ctx.SetOpType("input");
+ ctx.SetOpName(input.name());
+ auto out_shape = OnnxShapeToXlaShape(input.type(), batch_size_);
+ auto in_shape = out_shape;
+ in_shape.set_element_type(options_.io_type);
+ const auto* flow =
+ MakeParameter(std::string(input.name()), in_shape, false);
+ flow = builder_.Convert(flow, out_shape.element_type());
+ onnx_name_to_hlo_flow_[std::string(input.name())] = flow;
+ }
+ }
+
+ // Makes a parameter instruction (for inputs or large constants).
+ HloFlow MakeParameter(const std::string& name,
+ const pblczero::XlaShapeProto& shape,
+ bool is_constant) {
+ auto* res = builder_.Parameter(shape);
+ params_.push_back({name, res, is_constant});
+ return res;
+ }
+
+ void BuildGraph(const pblczero::GraphProto& graph) {
+ for (const auto& node : graph.node()) {
+ // Set up the context so that nodes have metadata from the original ONNX.
+ auto ctx = HloContext(&builder_);
+ ctx.SetOpType(node.op_type());
+ ctx.SetOpName(node.name());
+ DispatchNode(node);
+ }
+ }
+
+ // Calls the correct function to handle the ONNX node, and stores output in
+ // the map.
+ void DispatchNode(const pblczero::NodeProto& node) {
+ auto iter = onnx_op_to_builder_.find(std::string(node.op_type()));
+ if (iter == onnx_op_to_builder_.end()) {
+ throw Exception("Unsupported ONNX op[" + std::string(node.op_type()) +
+ "] name=[" + std::string(node.name()) + "]");
+ }
+ try {
+ auto outputs = (this->*iter->second)(node);
+ if (outputs.size() != node.output_size()) {
+ throw Exception("Node produced wrong number of outputs");
+ }
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ onnx_name_to_hlo_flow_[std::string(node.output(i))] = outputs[i];
+ }
+ } catch (Exception& e) {
+ throw Exception("Error in ONNX op[" + std::string(node.op_type()) +
+ "] name=[" + std::string(node.name()) + "]: " + e.what());
+ }
+ }
+
+ std::unordered_map<std::string, HloFlow> onnx_name_to_hlo_flow_;
+ std::unordered_map<std::string, std::vector<HloFlow> (Onnx2HloConverter::*)(
+ const pblczero::NodeProto&)>
+ onnx_op_to_builder_;
+ std::unordered_map<std::string, const pblczero::TensorProto*> initializers_;
+ HloBuilder builder_;
+ size_t batch_size_ = 0;
+ Onnx2HloOptions options_;
+ struct Param {
+ std::string name;
+ HloFlow flow;
+ bool is_constant;
+ };
+ std::vector<Param> params_;
+};
+
+} // namespace
+
+Onnx2HloResult ConvertOnnxToHlo(const pblczero::ModelProto& onnx_model,
+ size_t minibatch_size,
+ const Onnx2HloOptions& options) {
+ Onnx2HloConverter converter(options);
+ return converter.Convert(onnx_model, minibatch_size);
+}
+
+std::unique_ptr<XlaTensor> OnnxTensorToXlaTensor(
+ const pblczero::TensorProto& onnx_tensor) {
+ switch (onnx_tensor.data_type()) {
+ case pblczero::TensorProto::FLOAT:
+ return std::make_unique<XlaTensorNotOwned>(onnx_tensor.dims(),
+ onnx_tensor.raw_data(),
+ pblczero::XlaShapeProto::F32);
+ default:
+ throw Exception(
+ "Unsupported ONNX tensor type for buffer conversion " +
+ pblczero::TensorProto::DataType_Name(onnx_tensor.data_type()));
+ }
+}
+
+} // namespace lczero
\ No newline at end of file
diff --git a/src/neural/xla/onnx2hlo.h b/src/neural/xla/onnx2hlo.h
new file mode 100644
index 0000000000..284f08d1bc
--- /dev/null
+++ b/src/neural/xla/onnx2hlo.h
@@ -0,0 +1,73 @@
+/*
+ This file is part of Leela Chess Zero.
+ Copyright (C) 2024 The LCZero Authors
+
+ Leela Chess is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Chess is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Chess. If not, see .
+
+ Additional permission under GNU GPL version 3 section 7
+
+ If you modify this Program, or any covered work, by linking or
+ combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
+ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
+ modified version of those libraries), containing parts covered by the
+ terms of the respective license agreement, the licensors of this
+ Program grant you additional permission to convey the resulting work.
+*/
+
+#pragma once
+
+#include
+#include
+
+#include "neural/onnx/onnx.pb.h"
+#include "neural/xla/hlo.pb.h"
+#include "neural/xla/xla_runner.h"
+
+namespace lczero {
+
+struct Onnx2HloOptions {
+ // Constants larger than this size in bytes will be passed as parameters
+ // instead. This allows them to be shared between different modules.
+ size_t max_inline_constant_size = 1024;
+ // The types of input/output tensors (does not affect constants passed as
+ // parameters).
+ pblczero::XlaShapeProto::Type io_type = pblczero::XlaShapeProto::F32;
+};
+
+struct Onnx2HloResult {
+ struct NamedTensor {
+ // Index of the tensor in the input or output tuple.
+ size_t param_idx;
+ // Name of the tensor from the ONNX model.
+ std::string name;
+ pblczero::XlaShapeProto shape;
+ };
+ // Constants that are passed as inputs to the module.
+ std::vector<NamedTensor> constants;
+ std::vector<NamedTensor> inputs;
+ std::vector<NamedTensor> outputs;
+ pblczero::HloModuleProto hlo_module;
+};
+
+// Converts an ONNX model to an HLO module.
+Onnx2HloResult ConvertOnnxToHlo(const pblczero::ModelProto& onnx_model,
+ size_t minibatch_size,
+ const Onnx2HloOptions& options);
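+// A minimal usage sketch (illustrative only; the surrounding setup is assumed
+// and not part of this header):
+//
+//   // Given a deserialized pblczero::ModelProto onnx_model:
+//   Onnx2HloOptions options;
+//   Onnx2HloResult result =
+//       ConvertOnnxToHlo(onnx_model, /*minibatch_size=*/64, options);
+//   // result.hlo_module is then compiled (e.g. by XlaRunner::AddModule()),
+//   // and result.constants / result.inputs describe the parameters to bind.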
+
+// Converts an ONNX tensor to an XLA tensor (thanks GitHub Copilot for the
+// comment suggestion).
+std::unique_ptr<XlaTensor> OnnxTensorToXlaTensor(
+ const pblczero::TensorProto& onnx_tensor);
+
+} // namespace lczero
diff --git a/src/neural/xla/pjrt.cc b/src/neural/xla/pjrt.cc
new file mode 100644
index 0000000000..745a2d40f3
--- /dev/null
+++ b/src/neural/xla/pjrt.cc
@@ -0,0 +1,403 @@
+/*
+ This file is part of Leela Chess Zero.
+ Copyright (C) 2024 The LCZero Authors
+
+ Leela Chess is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Chess is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Chess. If not, see .
+
+ Additional permission under GNU GPL version 3 section 7
+
+ If you modify this Program, or any covered work, by linking or
+ combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
+ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
+ modified version of those libraries), containing parts covered by the
+ terms of the respective license agreement, the licensors of this
+ Program grant you additional permission to convey the resulting work.
+*/
+
+#include "pjrt.h"
+
+#include <dlfcn.h>
+
+#include
+#include
+#include
+#include
+
+#include "pjrt_c_api.h"
+#include "utils/logging.h"
+
+namespace lczero {
+
+namespace {
+static std::string value_to_string(const std::string& value) { return value; }
+static std::string value_to_string(int64_t value) {
+ return std::to_string(value);
+}
+static std::string value_to_string(const std::vector<int64_t>& value) {
+ std::string result;
+ for (auto v : value) {
+ if (!result.empty()) result += ", ";
+ result += std::to_string(v);
+ }
+ return result;
+}
+static std::string value_to_string(float value) {
+ return std::to_string(value);
+}
+static std::string value_to_string(bool value) {
+ return value ? "true" : "false";
+}
+
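+// The PJRT C API requires every argument struct to carry its own size in
+// `struct_size`, so plugins built against a different minor version can tell
+// which fields are present. MakeStruct() zero-initializes the struct and
+// fills that field in one place.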
+template <typename T>
+T MakeStruct() {
+ T t;
+ memset(&t, 0, sizeof(t));
+ t.struct_size = sizeof(t);
+ return t;
+}
+
+PJRT_Error_Code GetErrorCode(const PJRT_Api* api, PJRT_Error* error) {
+ auto args = MakeStruct<PJRT_Error_GetCode_Args>();
+ args.error = error;
+ api->PJRT_Error_GetCode(&args);
+ return args.code;
+}
+
+} // namespace
+
+std::string PjrtKeyValue::value_as_string() const {
+ return std::visit([&](const auto& arg) { return value_to_string(arg); },
+ value_);
+}
+
+PjrtKeyValue MakeKeyValue(const PJRT_NamedValue* kv) {
+ PjrtKeyValue result;
+ result.set_key({kv->name, kv->name_size});
+ switch (kv->type) {
+ case PJRT_NamedValue_kString:
+ result.set_value(std::string(kv->string_value, kv->value_size));
+ break;
+ case PJRT_NamedValue_kInt64:
+ result.set_value(kv->int64_value);
+ break;
+ case PJRT_NamedValue_kInt64List:
+ result.set_value(std::vector<int64_t>(
+ kv->int64_array_value, kv->int64_array_value + kv->value_size));
+ break;
+ case PJRT_NamedValue_kFloat:
+ result.set_value(kv->float_value);
+ break;
+ case PJRT_NamedValue_kBool:
+ result.set_value(kv->bool_value);
+ break;
+ }
+ return result;
+}
+
+std::string PjrtCommon::GetErrorMessage(PJRT_Error* error) const {
+ auto args = MakeStruct<PJRT_Error_Message_Args>();
+ args.error = error;
+ api_->PJRT_Error_Message(&args);
+ return std::string(args.message, args.message_size);
+}
+void PjrtCommon::DestroyErrorMessage(PJRT_Error* error) const {
+ assert(error);
+ auto args = MakeStruct<PJRT_Error_Destroy_Args>();
+ args.error = error;
+ api_->PJRT_Error_Destroy(&args);
+}
+
+void PjrtCommon::CheckError(PJRT_Error* error) const {
+ if (!error) return;
+ PjrtException exception(static_cast<PjrtErrorCode>(GetErrorCode(api_, error)),
+ GetErrorMessage(error));
+ DestroyErrorMessage(error);
+ throw exception;
+}
+
+PjrtExecutable::PjrtExecutable(const PJRT_Api* api,
+ PJRT_LoadedExecutable* executable)
+ : PjrtCommon(api), executable_(executable) {
+ auto args = MakeStruct<PJRT_LoadedExecutable_GetExecutable_Args>();
+ args.loaded_executable = executable_;
+ CheckError(api_->PJRT_LoadedExecutable_GetExecutable(&args));
+
+ auto args2 = MakeStruct<PJRT_Executable_NumOutputs_Args>();
+ args2.executable = args.executable;
+ CheckError(api_->PJRT_Executable_NumOutputs(&args2));
+ num_outputs_ = args2.num_outputs;
+}
+
+PjrtExecutable::~PjrtExecutable() {
+ auto args = MakeStruct<PJRT_LoadedExecutable_Destroy_Args>();
+ args.executable = executable_;
+ CheckError(api_->PJRT_LoadedExecutable_Destroy(&args));
+}
+
+size_t PjrtExecutable::GetNumOutputs() const { return num_outputs_; }
+
+std::vector<std::unique_ptr<PjrtDeviceBuffer>> PjrtExecutable::ExecuteBlocking(
+ const std::vector<PjrtDeviceBuffer*>& inputs) {
+ auto options = MakeStruct<PJRT_ExecuteOptions>();
+ options.num_non_donatable_input_indices = inputs.size();
+ std::vector<int64_t> non_donatable_indices(inputs.size());
+ // TODO the buffer 0 is actually donatable.
+ std::iota(non_donatable_indices.begin(), non_donatable_indices.end(), 0);
+ options.non_donatable_input_indices = non_donatable_indices.data();
+
+ auto args = MakeStruct<PJRT_LoadedExecutable_Execute_Args>();
+ args.executable = executable_;
+ args.options = &options;
+ args.num_devices = 1;
+ std::vector<PJRT_Buffer*> buffers(inputs.size());
+ for (size_t i = 0; i < inputs.size(); ++i) buffers[i] = inputs[i]->buffer_;
+ PJRT_Buffer* const* buffers_ptr = buffers.data();
+ args.num_args = inputs.size();
+ args.argument_lists = &buffers_ptr;
+
+ std::vector<PJRT_Buffer*> outputs(num_outputs_);
+ PJRT_Buffer** outputs_ptr = outputs.data();
+ PJRT_Event* event_ptr;
+ args.output_lists = &outputs_ptr;
+ args.device_complete_events = &event_ptr;
+ CheckError(api_->PJRT_LoadedExecutable_Execute(&args));
+
+ PjrtEvent event(api_, event_ptr);
+ event.Await();
+
+ std::vector<std::unique_ptr<PjrtDeviceBuffer>> output_buffers;
+ output_buffers.reserve(num_outputs_);
+ for (size_t i = 0; i < num_outputs_; ++i) {
+ output_buffers.push_back(
+ std::make_unique<PjrtDeviceBuffer>(api_, outputs[i]));
+ }
+ return output_buffers;
+}
+
+PjrtDevice::PjrtDevice(const PJRT_Api* api, PJRT_Device* device)
+ : PjrtCommon(api), device_(device) {
+ auto args = MakeStruct<PJRT_Device_GetDescription_Args>();
+ args.device = device_;
+ CheckError(api_->PJRT_Device_GetDescription(&args));
+ description_ = args.device_description;
+}
+
+std::string PjrtDevice::ToString() const {
+ auto args = MakeStruct<PJRT_DeviceDescription_ToString_Args>();
+ args.device_description = description_;
+ CheckError(api_->PJRT_DeviceDescription_ToString(&args));
+ return {args.to_string, args.to_string_size};
+}
+
+PjrtClient::PjrtClient(const PJRT_Api* api, PJRT_Client* client)
+ : PjrtCommon(api), client_(client) {}
+
+PjrtClient::~PjrtClient() {
+ auto args = MakeStruct<PJRT_Client_Destroy_Args>();
+ args.client = client_;
+ CheckError(api_->PJRT_Client_Destroy(&args));
+}
+
+std::unique_ptr<PjrtExecutable> PjrtClient::CompileHlo(
+ std::string_view hlo, std::string_view config) {
+ constexpr std::string_view kFormat = "hlo";
+ auto program = MakeStruct<PJRT_Program>();
+ program.code = const_cast<char*>(hlo.data());
+ program.code_size = hlo.size();
+ program.format = kFormat.data();
+ program.format_size = kFormat.size();
+
+ auto args = MakeStruct<PJRT_Client_Compile_Args>();
+ args.client = client_;
+ args.program = &program;
+ args.compile_options = const_cast<char*>(config.data());
+ args.compile_options_size = config.size();
+ CheckError(api_->PJRT_Client_Compile(&args));
+ return std::make_unique<PjrtExecutable>(api_, args.executable);
+}
+
+std::vector<std::unique_ptr<PjrtDevice>> PjrtClient::GetDevices() {
+ auto args = MakeStruct<PJRT_Client_Devices_Args>();
+ args.client = client_;
+ CheckError(api_->PJRT_Client_Devices(&args));
+ std::vector<std::unique_ptr<PjrtDevice>> result;
+ result.reserve(args.num_devices);
+ for (size_t i = 0; i < args.num_devices; ++i) {
+ result.push_back(std::make_unique<PjrtDevice>(api_, args.devices[i]));
+ }
+ return result;
+}
+
+PjrtEvent::PjrtEvent(const PJRT_Api* api, PJRT_Event* event)
+ : PjrtCommon(api), event_(event) {}
+
+PjrtEvent::~PjrtEvent() {
+ auto args = MakeStruct<PJRT_Event_Destroy_Args>();
+ args.event = event_;
+ CheckError(api_->PJRT_Event_Destroy(&args));
+}
+
+void PjrtEvent::Await() {
+ auto args = MakeStruct<PJRT_Event_Await_Args>();
+ args.event = event_;
+ CheckError(api_->PJRT_Event_Await(&args));
+}
+
+PjrtDeviceBuffer::PjrtDeviceBuffer(const PJRT_Api* api, PJRT_Buffer* buffer)
+ : PjrtCommon(api), buffer_(buffer) {}
+
+PjrtDeviceBuffer::~PjrtDeviceBuffer() {
+ auto args = MakeStruct<PJRT_Buffer_Destroy_Args>();
+ args.buffer = buffer_;
+ CheckError(api_->PJRT_Buffer_Destroy(&args));
+}
+
+size_t PjrtDeviceBuffer::GetSize() const {
+ auto args = MakeStruct<PJRT_Buffer_ToHostBuffer_Args>();
+ args.src = buffer_;
+ CheckError(api_->PJRT_Buffer_ToHostBuffer(&args));
+ return args.dst_size;
+}
+
+PjrtType PjrtDeviceBuffer::GetType() const {
+ auto args = MakeStruct<PJRT_Buffer_ElementType_Args>();
+ args.buffer = buffer_;
+ CheckError(api_->PJRT_Buffer_ElementType(&args));
+ return static_cast<PjrtType>(args.type);
+}
+
+std::vector<int64_t> PjrtDeviceBuffer::GetDimensions() const {
+ auto args = MakeStruct<PJRT_Buffer_Dimensions_Args>();
+ args.buffer = buffer_;
+ CheckError(api_->PJRT_Buffer_Dimensions(&args));
+ return {args.dims, args.dims + args.num_dims};
+}
+
+std::unique_ptr<PjrtEvent> PjrtDeviceBuffer::DeviceToHost(void* dst,
+ size_t size) {
+ auto args = MakeStruct<PJRT_Buffer_ToHostBuffer_Args>();
+ args.src = buffer_;
+ args.dst = dst;
+ args.dst_size = size;
+ CheckError(api_->PJRT_Buffer_ToHostBuffer(&args));
+ return std::make_unique<PjrtEvent>(api_, args.event);
+}
+
+PjrtHostToDeviceTransfer::PjrtHostToDeviceTransfer(
+ const PJRT_Api* api, PJRT_Buffer* buffer, std::unique_ptr<PjrtEvent> event)
+ : PjrtCommon(api), buffer_(buffer), event_(std::move(event)) {}
+
+void PjrtHostToDeviceTransfer::Await() { event_->Await(); }
+
+std::unique_ptr<PjrtDeviceBuffer>
+PjrtHostToDeviceTransfer::AwaitAndReleaseBuffer() {
+ if (!buffer_) {
+ throw PjrtException(PjrtErrorCode::INVALID_ARGUMENT,
+ "Buffer already released");
+ }
+ Await();
+ auto res = std::make_unique<PjrtDeviceBuffer>(api_, buffer_);
+ buffer_ = nullptr;
+ return res;
+}
+
+PjrtHostToDeviceTransfer::~PjrtHostToDeviceTransfer() {
+ Await();
+ if (buffer_) {
+ auto args = MakeStruct<PJRT_Buffer_Destroy_Args>();
+ args.buffer = buffer_;
+ CheckError(api_->PJRT_Buffer_Destroy(&args));
+ }
+}
+
+std::unique_ptr<PjrtHostToDeviceTransfer> PjrtClient::HostToDevice(
+ std::string_view buffer, PjrtType type, const std::vector<int64_t>& dims,
+ const PjrtDevice* device) {
+ auto args = MakeStruct<PJRT_Client_BufferFromHostBuffer_Args>();
+ args.client = client_;
+ args.data = buffer.data();
+ args.type = static_cast<PJRT_Buffer_Type>(type);
+ args.dims = dims.data();
+ args.num_dims = dims.size();
+ args.host_buffer_semantics =
+ PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes;
+ args.device = device->device_;
+ CheckError(api_->PJRT_Client_BufferFromHostBuffer(&args));
+ auto event = std::make_unique<PjrtEvent>(api_, args.done_with_host_buffer);
+ return std::make_unique<PjrtHostToDeviceTransfer>(api_, args.buffer,
+ std::move(event));
+}
+
+Pjrt::Pjrt(const char* library_path) : PjrtCommon(nullptr) {
+ // TODO factor out the dlopen/dlsym code into a separate function, and
+ // implement for other OSes.
+ void* handle = dlopen(library_path, RTLD_LAZY);
+ if (!handle) {
+ throw PjrtException(PjrtErrorCode::INVALID_ARGUMENT,
+ "Unable to load PJRT library " +
+ std::string(library_path) + ": " + dlerror());
+ }
+ typedef const PJRT_Api* (*PjrtApiFunc)();
+ auto func = reinterpret_cast<PjrtApiFunc>(dlsym(handle, "GetPjrtApi"));
+ if (!func) {
+ throw PjrtException(PjrtErrorCode::INVALID_ARGUMENT,
+ "Unable to find GetPjrtApi() in PJRT library " +
+ std::string(library_path) + ": " + dlerror());
+ }
+ api_ = func();
+ if (!api_) {
+ throw PjrtException(PjrtErrorCode::INVALID_ARGUMENT,
+ "GetPjrtApi() returned nullptr in PJRT library " +
+ std::string(library_path));
+ }
+ auto [major, minor] = ApiVersion();
+ if (major != PJRT_API_MAJOR || minor < PJRT_API_MINOR) {
+ throw PjrtException(
+ PjrtErrorCode::INVALID_ARGUMENT,
+ "PJRT library " + std::string(library_path) +
+ " has incompatible API version: " + std::to_string(major) + "." +
+ std::to_string(minor) + " vs " + std::to_string(PJRT_API_MAJOR) +
+ "." + std::to_string(PJRT_API_MINOR));
+ }
+ Initialize();
+}
+
+std::vector<PjrtKeyValue> Pjrt::GetAttributes() const {
+ auto args = MakeStruct<PJRT_Plugin_Attributes_Args>();
+ CheckError(api_->PJRT_Plugin_Attributes(&args));
+ std::vector<PjrtKeyValue> result;
+ result.reserve(args.num_attributes);
+ for (size_t i = 0; i < args.num_attributes; ++i) {
+ result.push_back(MakeKeyValue(args.attributes + i));
+ }
+ return result;
+}
+
+std::unique_ptr<PjrtClient> Pjrt::CreateClient() {
+ auto args = MakeStruct<PJRT_Client_Create_Args>();
+ CheckError(api_->PJRT_Client_Create(&args));
+ return std::make_unique<PjrtClient>(api_, args.client);
+}
+
+std::pair<int, int> Pjrt::ApiVersion() const {
+ return std::make_pair(api_->pjrt_api_version.major_version,
+ api_->pjrt_api_version.minor_version);
+}
+
+void Pjrt::Initialize() {
+ auto args = MakeStruct<PJRT_Plugin_Initialize_Args>();
+ CheckError(api_->PJRT_Plugin_Initialize(&args));
+}
+
+} // namespace lczero
\ No newline at end of file
diff --git a/src/neural/xla/pjrt.h b/src/neural/xla/pjrt.h
new file mode 100644
index 0000000000..1fd250e34c
--- /dev/null
+++ b/src/neural/xla/pjrt.h
@@ -0,0 +1,252 @@
+/*
+ This file is part of Leela Chess Zero.
+ Copyright (C) 2024 The LCZero Authors
+
+ Leela Chess is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Chess is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Chess. If not, see .
+
+ Additional permission under GNU GPL version 3 section 7
+
+ If you modify this Program, or any covered work, by linking or
+ combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
+ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
+ modified version of those libraries), containing parts covered by the
+ terms of the respective license agreement, the licensors of this
+ Program grant you additional permission to convey the resulting work.
+*/
+
+// This file contains a set of C++ wrappers around the PJRT C API.
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+struct PJRT_Api;
+struct PJRT_Buffer;
+struct PJRT_Client;
+struct PJRT_Device;
+struct PJRT_DeviceDescription;
+struct PJRT_Error;
+struct PJRT_Event;
+struct PJRT_LoadedExecutable;
+
+namespace lczero {
+
+// PJRT_Error_Code as enum class. Coincidentally, the error codes are the same
+// as in absl Status module.
+enum class PjrtErrorCode {
+ CANCELLED = 1,
+ UNKNOWN = 2,
+ INVALID_ARGUMENT = 3,
+ DEADLINE_EXCEEDED = 4,
+ NOT_FOUND = 5,
+ ALREADY_EXISTS = 6,
+ PERMISSION_DENIED = 7,
+ RESOURCE_EXHAUSTED = 8,
+ FAILED_PRECONDITION = 9,
+ ABORTED = 10,
+ OUT_OF_RANGE = 11,
+ UNIMPLEMENTED = 12,
+ INTERNAL = 13,
+ UNAVAILABLE = 14,
+ DATA_LOSS = 15,
+ UNAUTHENTICATED = 16
+};
+
+// PJRT_Type as enum class. Coincidentally, the types are the same as in XLA
+// and HloModuleProto, so a simple cast works.
+enum class PjrtType {
+ INVALID,
+ PRED,
+ S8,
+ S16,
+ S32,
+ S64,
+ U8,
+ U16,
+ U32,
+ U64,
+ F16,
+ F32,
+ F64,
+ BF16,
+ C64,
+ C128,
+ F8E5M2,
+ F8E4M3FN,
+ F8E4M3B11FNUZ,
+ F8E5M2FNUZ,
+ F8E4M3FNUZ,
+ S4,
+ U4,
+};
+
+// PJRT errors as exceptions.
+class PjrtException : public std::exception {
+ public:
+ explicit PjrtException(PjrtErrorCode code, const std::string& message)
+ : message_(message), code_(code) {}
+
+ const char* what() const noexcept override { return message_.data(); }
+ PjrtErrorCode code() const { return code_; }
+
+ private:
+ std::string message_;
+ PjrtErrorCode code_;
+};
+
+// PJRT_NamedValue wrapper. PJRT_NamedValue is a string-keyed value used for
+// auxiliary functionality like plugin attributes.
+class PjrtKeyValue {
+ public:
+ PjrtKeyValue() = default;
+ PjrtKeyValue(const PjrtKeyValue&) = default;
+ PjrtKeyValue(PjrtKeyValue&&) = default;
+ template <typename T>
+ PjrtKeyValue(const std::string& k, const T& v) : key_(k), value_(v) {}
+
+ const std::string& key() const { return key_; }
+ // Converts the value to string. This is useful for logging and debugging.
+ std::string value_as_string() const;
+
+ void set_key(const std::string& key) { key_ = key; }
+ void set_value(const std::string& value) { value_ = value; }
+ void set_value(int64_t value) { value_ = value; }
+ void set_value(const std::vector<int64_t>& value) { value_ = value; }
+ void set_value(float value) { value_ = value; }
+ void set_value(bool value) { value_ = value; }
+
+ private:
+ std::string key_;
+ std::variant<std::string, int64_t, std::vector<int64_t>, float, bool> value_;
+};
+
+// A shared base class for all wrappers. Keeps the API pointer and auxiliary
+// functions like error checking.
+class PjrtCommon {
+ protected:
+ PjrtCommon(const PJRT_Api* api) : api_(api) {}
+ virtual ~PjrtCommon() = default;
+
+ std::string GetErrorMessage(PJRT_Error* error) const;
+ void DestroyErrorMessage(PJRT_Error* error) const;
+ void CheckError(PJRT_Error* error) const;
+
+ const PJRT_Api* api_;
+};
+
+class PjrtDevice : protected PjrtCommon {
+ public:
+ PjrtDevice(const PJRT_Api* api, PJRT_Device* device);
+ std::string ToString() const;
+
+ private:
+ PJRT_Device* device_;
+ PJRT_DeviceDescription* description_;
+ friend class PjrtExecutable;
+ friend class PjrtClient;
+};
+
+// An event for waiting for asynchronous operations.
+class PjrtEvent : protected PjrtCommon {
+ public:
+ PjrtEvent(const PJRT_Api* api, PJRT_Event* event);
+ // Blocks until the operation is complete.
+ void Await();
+ ~PjrtEvent();
+
+ private:
+ PJRT_Event* event_;
+};
+
+// A buffer in the device memory.
+class PjrtDeviceBuffer : protected PjrtCommon {
+ public:
+ PjrtDeviceBuffer(const PJRT_Api* api, PJRT_Buffer* buffer);
+ ~PjrtDeviceBuffer();
+ // Returns the size of the buffer in bytes.
+ size_t GetSize() const;
+ // Starts an asynchronous copy from the device to the host memory.
+ [[nodiscard]] std::unique_ptr<PjrtEvent> DeviceToHost(void* dst, size_t size);
+ PjrtType GetType() const;
+ std::vector<int64_t> GetDimensions() const;
+
+ private:
+ PJRT_Buffer* buffer_;
+ friend class PjrtExecutable;
+};
+
+class PjrtExecutable : protected PjrtCommon {
+ public:
+ PjrtExecutable(const PJRT_Api* api, PJRT_LoadedExecutable* executable);
+ ~PjrtExecutable();
+ // Executes the executable with the given inputs. The inputs are not owned or
+ // modified. The function allocates the output buffers and returns them.
+ std::vector<std::unique_ptr<PjrtDeviceBuffer>> ExecuteBlocking(
+ const std::vector<PjrtDeviceBuffer*>& inputs);
+ size_t GetNumOutputs() const;
+
+ private:
+ PJRT_LoadedExecutable* executable_;
+ size_t num_outputs_;
+};
+
+// Ongoing host-to-device transfer. After the transfer is complete, it's
+// possible to fetch the device buffer.
+class PjrtHostToDeviceTransfer : protected PjrtCommon {
+ public:
+ PjrtHostToDeviceTransfer(const PJRT_Api* api, PJRT_Buffer* buffer,
+ std::unique_ptr<PjrtEvent> event);
+ ~PjrtHostToDeviceTransfer();
+ // Blocks until the transfer is complete. (not really necessary as
+ // AwaitAndReleaseBuffer() waits anyway)
+ void Await();
+ // Waits for the transfer to complete and releases the ownership of the
+ // buffer.
+ std::unique_ptr<PjrtDeviceBuffer> AwaitAndReleaseBuffer();
+
+ private:
+ PJRT_Buffer* buffer_;
+ std::unique_ptr<PjrtEvent> event_;
+};
+
+class PjrtClient : protected PjrtCommon {
+ public:
+ PjrtClient(const PJRT_Api* api, PJRT_Client* client);
+ ~PjrtClient();
+ std::unique_ptr<PjrtExecutable> CompileHlo(std::string_view hlo,
+ std::string_view config);
+ std::vector<std::unique_ptr<PjrtDevice>> GetDevices();
+ std::unique_ptr<PjrtHostToDeviceTransfer> HostToDevice(
+ std::string_view buffer, PjrtType type, const std::vector<int64_t>& dims,
+ const PjrtDevice* device);
+
+ private:
+ PJRT_Client* client_;
+};
+
+class Pjrt : protected PjrtCommon {
+ public:
+ Pjrt(const char* library_path);
+ std::vector<PjrtKeyValue> GetAttributes() const;
+ std::unique_ptr<PjrtClient> CreateClient();
+ std::pair<int, int> ApiVersion() const;
+
+ private:
+ void Initialize();
+};
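+
+// A minimal usage sketch (the plugin path is illustrative, not part of this
+// API):
+//
+//   Pjrt pjrt("./pjrt_c_api_gpu_plugin.so");
+//   auto client = pjrt.CreateClient();
+//   auto devices = client->GetDevices();
+//   auto executable = client->CompileHlo(serialized_hlo, serialized_options);
+//   auto outputs = executable->ExecuteBlocking(input_buffers);
+//
+// where serialized_hlo, serialized_options and input_buffers are assumed to be
+// prepared by the caller (see xla_runner.cc for the actual usage).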
+
+} // namespace lczero
diff --git a/src/neural/xla/print_hlo.cc b/src/neural/xla/print_hlo.cc
new file mode 100644
index 0000000000..fa7d406bc3
--- /dev/null
+++ b/src/neural/xla/print_hlo.cc
@@ -0,0 +1,390 @@
+/*
+ This file is part of Leela Chess Zero.
+ Copyright (C) 2024 The LCZero Authors
+
+ Leela Chess is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Chess is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Chess. If not, see .
+
+ Additional permission under GNU GPL version 3 section 7
+
+ If you modify this Program, or any covered work, by linking or
+ combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
+ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
+ modified version of those libraries), containing parts covered by the
+ terms of the respective license agreement, the licensors of this
+ Program grant you additional permission to convey the resulting work.
+*/
+
+#include "neural/xla/print_hlo.h"
+
+namespace lczero {
+namespace {
+
+// C-escapes the given string, and appends double quotes around it.
+std::string CEscape(std::string_view str) {
+ std::string result = "\"";
+ for (char c : str) {
+ switch (c) {
+ case '\n':
+ result += "\\n";
+ break;
+ case '\t':
+ result += "\\t";
+ break;
+ case '\r':
+ result += "\\r";
+ break;
+ case '\v':
+ result += "\\v";
+ break;
+ case '\f':
+ result += "\\f";
+ break;
+ case '\\':
+ result += "\\\\";
+ break;
+ case '\"':
+ result += "\\\"";
+ break;
+ case '\'':
+ result += "\\\'";
+ break;
+ default:
+ result += c;
+ }
+ }
+ return result + "\"";
+}
+
+class HloPrettyPrinter {
+ public:
+ HloPrettyPrinter(PrettyPrintHloOptions options, std::ostream& stream)
+ : options_(options), s_(stream) {}
+
+ void PrintModule(const pblczero::HloModuleProto& module) {
+ current_module_ = &module;
+ s_ << "HloModule " << module.name();
+ if (module.has_host_program_shape()) {
+ s_ << ", entry_computation_layout=";
+ PrintProgramShape(module.host_program_shape());
+ }
+ s_ << "\n";
+
+ for (const auto& computation : module.computations()) {
+ s_ << "\n";
+ if (module.entry_computation_id() == computation.id()) s_ << "ENTRY ";
+ PrintComputation(computation);
+ }
+ current_module_ = nullptr;
+ }
+
+ private:
+ // Prints a delimited list, using a value-rendering function and optional
+ // prefix and suffix. E.g. for vec=[1,2,3], print_fn=(x) -> print(x, x),
+ // delim="," and prefix="{" and suffix="}", it prints "{11,22,33}".
+ template <typename T, typename F>
+ void PrintDelimeted(const T& vec, F print_fn, std::string_view delim,
+ std::string_view prefix = "",
+ std::string_view suffix = "") {
+ s_ << prefix;
+ for (size_t i = 0; i < vec.size(); ++i) {
+ if (i > 0) s_ << delim;
+ print_fn(vec[i]);
+ }
+ s_ << suffix;
+ }
+
+ // Returns the name of the type, which is the lowercase enum value name.
+ std::string GetTypeLiteral(pblczero::XlaShapeProto::Type type) {
+ std::string name = pblczero::XlaShapeProto::Type_Name(type);
+ for (char& c : name) c = std::tolower(c);
+ return name;
+ }
+
+ // Prints the tensor layout (e.g. {3,2,1,0} for major-to-minor layout).
+ void PrintLayout(const pblczero::XlaLayoutProto& layout) {
+ if (!options_.print_layout) return;
+ PrintDelimeted(
+ layout.minor_to_major(), [&](const auto& dim) { s_ << dim; }, ",", " {",
+ "}");
+ }
+
+ // Prints the shape of a tensor, including the type (e.g. f32[112,8,8]).
+ void PrintShape(const pblczero::XlaShapeProto& shape) {
+ if (shape.element_type() == pblczero::XlaShapeProto::TUPLE) {
+ PrintDelimeted(
+ shape.tuple_shapes(), [&](const auto& s) { PrintShape(s); }, ", ",
+ "(", ")");
+ return;
+ }
+ s_ << GetTypeLiteral(shape.element_type());
+ PrintDelimeted(
+ shape.dimensions(), [&](int64_t dim) { s_ << dim; }, ",", "[", "]");
+ if (shape.has_layout()) PrintLayout(shape.layout());
+ }
+
+ // Prints the program shape (i.e. shapes of parameters and output).
+ void PrintProgramShape(const pblczero::XlaProgramShapeProto& shape) {
+ s_ << "{(";
+ for (size_t i = 0; i < shape.parameters_size(); ++i) {
+ if (i > 0) s_ << ", ";
+ if (shape.parameter_names_size() > i &&
+ !shape.parameter_names(i).empty()) {
+ s_ << shape.parameter_names(i) << ": ";
+ }
+ PrintShape(shape.parameters(i));
+ }
+ s_ << ") -> ";
+ PrintShape(shape.result());
+ s_ << "}";
+ }
+
+ // Prints the literal (i.e. constant value).
+ void PrintLiteral(const pblczero::XlaLiteralProto& literal) {
+ // For now just print as a flat array, sometimes with the wrong encoding
+ // (e.g. in the bf16 case).
+ auto print_array = [&](const auto& array) {
+ PrintDelimeted(
+ array,
+ [&](const auto& x) {
+ if constexpr (std::is_same_v<std::decay_t<decltype(x)>, char> ||
+ std::is_same_v<std::decay_t<decltype(x)>, bool>) {
+ s_ << static_cast(x);
+ } else {
+ s_ << x;
+ }
+ },
+ ",");
+ };
+ switch (literal.shape().element_type()) {
+ case pblczero::XlaShapeProto::TUPLE:
+ PrintDelimeted(
+ literal.tuple_literals(), [&](const auto& l) { PrintLiteral(l); },
+ ", ", "(", ")");
+ break;
+ case pblczero::XlaShapeProto::TOKEN:
+ s_ << "token";
+ break;
+ case pblczero::XlaShapeProto::PRED:
+ return print_array(literal.preds());
+ case pblczero::XlaShapeProto::S4:
+ return print_array(literal.s4s());
+ case pblczero::XlaShapeProto::U4:
+ return print_array(literal.u4s());
+ case pblczero::XlaShapeProto::S8:
+ return print_array(literal.s8s());
+ case pblczero::XlaShapeProto::U8:
+ return print_array(literal.u8s());
+ case pblczero::XlaShapeProto::S32:
+ return print_array(literal.s32s());
+ case pblczero::XlaShapeProto::S64:
+ return print_array(literal.s64s());
+ case pblczero::XlaShapeProto::U32:
+ return print_array(literal.u32s());
+ case pblczero::XlaShapeProto::U64:
+ return print_array(literal.u64s());
+ case pblczero::XlaShapeProto::F32:
+ return print_array(literal.f32s());
+ case pblczero::XlaShapeProto::F64:
+ return print_array(literal.f64s());
+ case pblczero::XlaShapeProto::C64:
+ return print_array(literal.c64s());
+ case pblczero::XlaShapeProto::C128:
+ return print_array(literal.c128s());
+ case pblczero::XlaShapeProto::F16:
+ return print_array(literal.f16s());
+ case pblczero::XlaShapeProto::BF16:
+ return print_array(literal.bf16s());
+ case pblczero::XlaShapeProto::U16:
+ return print_array(literal.u16s());
+ case pblczero::XlaShapeProto::S16:
+ return print_array(literal.s16s());
+ case pblczero::XlaShapeProto::F8E5M2:
+ return print_array(literal.f8e5m2s());
+ case pblczero::XlaShapeProto::F8E4M3FN:
+ return print_array(literal.f8e4m3fns());
+ case pblczero::XlaShapeProto::F8E4M3B11FNUZ:
+ return print_array(literal.f8e4m3b11fnuzs());
+ case pblczero::XlaShapeProto::F8E5M2FNUZ:
+ return print_array(literal.f8e5m2fnuzs());
+ case pblczero::XlaShapeProto::F8E4M3FNUZ:
+ return print_array(literal.f8e4m3fnuzs());
+ case pblczero::XlaShapeProto::PRIMITIVE_TYPE_INVALID:
+ s_ << "INVALID";
+ break;
+ case pblczero::XlaShapeProto::OPAQUE_TYPE:
+ s_ << "opaque";
+ break;
+ }
+ }
+
+ // Prints the operands of the given instruction. Usually operands are stored
+ // in operands() fields, but some opcodes have operands in the other fields.
+ void PrintInstructionOperands(
+ const pblczero::HloInstructionProto& instruction) {
+ s_ << "(";
+ if (instruction.opcode() == "parameter") {
+ s_ << instruction.parameter_number();
+ } else if (instruction.opcode() == "get-tuple-index") {
+ s_ << instruction.tuple_index();
+ } else if (instruction.opcode() == "constant") {
+ PrintLiteral(instruction.literal());
+ } else {
+ PrintDelimeted(
+ instruction.operand_ids(),
+ [&](int64_t id) {
+ s_ << "%" << current_computation_->instructions(id).name();
+ },
+ ", ");
+ }
+ s_ << ")";
+ }
+
+ // Prints the "window" attribute (for convolution opcodes).
+ void PrintWindow(const pblczero::XlaWindow& window) {
+ PrintDelimeted(
+ window.dimensions(), [&](const auto& d) { s_ << d.size(); }, "x",
+ "size=");
+ PrintDelimeted(
+ window.dimensions(),
+ [&](const auto& d) {
+ s_ << d.padding_low() << "_" << d.padding_high();
+ },
+ "x", " pads=");
+ }
+
+ // Prints the "dimension numbers" attribute (for dot opcodes).
+ void PrintDotDimensionNumbers(const pblczero::XlaDotDimensionNumbers& dn) {
+ PrintDelimeted(
+ dn.lhs_batch_dimensions(), [&](int64_t dim) { s_ << dim; }, ",",
+ ", lhs_batch_dims={", "}");
+ PrintDelimeted(
+ dn.rhs_batch_dimensions(), [&](int64_t dim) { s_ << dim; }, ",",
+ ", rhs_batch_dims={", "}");
+ PrintDelimeted(
+ dn.lhs_contracting_dimensions(), [&](int64_t dim) { s_ << dim; }, ",",
+ ", lhs_contracting_dims={", "}");
+ PrintDelimeted(
+ dn.rhs_contracting_dimensions(), [&](int64_t dim) { s_ << dim; }, ",",
+ ", rhs_contracting_dims={", "}");
+ }
+
+ // Prints the "dimension numbers" attribute (for convolution opcodes).
+ void PrintConvolutionDimensionNumbers(
+ const pblczero::XlaConvolutionDimensionNumbers& dn) {
+ std::string input_dims(dn.input_spatial_dimensions_size() + 2, '?');
+ std::string kernel_dims(dn.kernel_spatial_dimensions_size() + 2, '?');
+ std::string output_dims(dn.output_spatial_dimensions_size() + 2, '?');
+ input_dims[dn.input_batch_dimension()] = 'b';
+ input_dims[dn.input_feature_dimension()] = 'f';
+ kernel_dims[dn.kernel_output_feature_dimension()] = 'o';
+ kernel_dims[dn.kernel_input_feature_dimension()] = 'i';
+ output_dims[dn.output_batch_dimension()] = 'b';
+ output_dims[dn.output_feature_dimension()] = 'f';
+ for (size_t i = 0; i < dn.input_spatial_dimensions_size(); ++i) {
+ input_dims[dn.input_spatial_dimensions(i)] = '0' + i;
+ kernel_dims[dn.kernel_spatial_dimensions(i)] = '0' + i;
+ output_dims[dn.output_spatial_dimensions(i)] = '0' + i;
+ }
+ s_ << input_dims << "_" << kernel_dims << "->" << output_dims;
+ }
+
+ // Prints the attributes of the given instruction.
+ void PrintInstructionAttributes(
+ const pblczero::HloInstructionProto& instruction) {
+ if (instruction.called_computation_ids_size() > 0) {
+ PrintDelimeted(
+ instruction.called_computation_ids(),
+ [&](int64_t id) { s_ << current_module_->computations(id).name(); },
+ ",", ", calls={", "}");
+ }
+ if (instruction.has_window()) {
+ s_ << ", window={";
+ PrintWindow(instruction.window());
+ s_ << "}";
+ }
+ if (instruction.has_convolution_dimension_numbers()) {
+ s_ << ", dim_labels=";
+ PrintConvolutionDimensionNumbers(
+ instruction.convolution_dimension_numbers());
+ }
+ if (instruction.dimensions_size() > 0) {
+ PrintDelimeted(
+ instruction.dimensions(), [&](int64_t dim) { s_ << dim; }, ", ",
+ ", dimensions={", "}");
+ }
+ if (instruction.has_dot_dimension_numbers()) {
+ PrintDotDimensionNumbers(instruction.dot_dimension_numbers());
+ }
+ }
+
+ // Prints the metadata of the given instruction (source file, line, etc).
+ void PrintInstructionMetadata(
+ const pblczero::HloInstructionProto& instruction) {
+ if (instruction.has_metadata()) {
+ const auto& m = instruction.metadata();
+ s_ << ", metadata={";
+ bool first = true;
+ auto sep = [&]() -> std::ostream& {
+ if (!first) s_ << ", ";
+ first = false;
+ return s_;
+ };
+ if (m.has_op_type()) sep() << "op_type=" << CEscape(m.op_type());
+ if (m.has_op_name()) sep() << "op_name=" << CEscape(m.op_name());
+ if (m.has_source_file())
+ sep() << "source_file=" << CEscape(m.source_file());
+ if (m.has_source_line()) sep() << "source_line=" << m.source_line();
+ s_ << "}";
+ }
+ }
+
+ // Prints the given instruction line.
+ void PrintInstruction(const pblczero::HloInstructionProto& instruction) {
+ s_ << "%" << instruction.name() << " = ";
+ PrintShape(instruction.shape());
+ s_ << " " << instruction.opcode();
+ PrintInstructionOperands(instruction);
+ PrintInstructionAttributes(instruction);
+ PrintInstructionMetadata(instruction);
+ }
+
+ // Prints the given computation.
+ void PrintComputation(const pblczero::HloComputationProto& computation) {
+ current_computation_ = &computation;
+ s_ << computation.name() << " {\n";
+ for (const auto& instruction : computation.instructions()) {
+ s_ << " ";
+ if (computation.root_id() == instruction.id()) s_ << "ROOT ";
+ PrintInstruction(instruction);
+ s_ << "\n";
+ }
+ s_ << "}\n";
+ current_computation_ = nullptr;
+ }
+
+ PrettyPrintHloOptions options_;
+ const pblczero::HloModuleProto* current_module_ = nullptr;
+ const pblczero::HloComputationProto* current_computation_ = nullptr;
+ std::ostream& s_;
+};
+
+} // namespace
+
+void PrettyPrintHlo(const pblczero::HloModuleProto& module,
+ PrettyPrintHloOptions options, std::ostream& stream) {
+ HloPrettyPrinter(options, stream).PrintModule(module);
+}
+
+} // namespace lczero
\ No newline at end of file
diff --git a/src/neural/xla/print_hlo.h b/src/neural/xla/print_hlo.h
new file mode 100644
index 0000000000..28db989ef4
--- /dev/null
+++ b/src/neural/xla/print_hlo.h
@@ -0,0 +1,44 @@
+/*
+ This file is part of Leela Chess Zero.
+ Copyright (C) 2024 The LCZero Authors
+
+ Leela Chess is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Chess is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Chess. If not, see .
+
+ Additional permission under GNU GPL version 3 section 7
+
+ If you modify this Program, or any covered work, by linking or
+ combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
+ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
+ modified version of those libraries), containing parts covered by the
+ terms of the respective license agreement, the licensors of this
+ Program grant you additional permission to convey the resulting work.
+*/
+
+#include
+
+#include "neural/xla/hlo.pb.h"
+
+namespace lczero {
+
+struct PrettyPrintHloOptions {
+ // Print layout information (which is always major-to-minor now, e.g.
+ // {3,2,1,0}). Switched off by default as it's just noise.
+ bool print_layout = false;
+};
+
+// Pretty-prints the given HLO module to the given stream.
+void PrettyPrintHlo(const pblczero::HloModuleProto& module,
+ PrettyPrintHloOptions options, std::ostream& stream);
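+
+// Example (assuming `result` is an Onnx2HloResult from onnx2hlo.h):
+//   PrettyPrintHlo(result.hlo_module, {}, std::cout);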
+
+} // namespace lczero
\ No newline at end of file
diff --git a/src/neural/xla/xla_runner.cc b/src/neural/xla/xla_runner.cc
new file mode 100644
index 0000000000..922d38f7fc
--- /dev/null
+++ b/src/neural/xla/xla_runner.cc
@@ -0,0 +1,214 @@
+/*
+ This file is part of Leela Chess Zero.
+ Copyright (C) 2024 The LCZero Authors
+
+ Leela Chess is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Chess is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Chess. If not, see .
+
+ Additional permission under GNU GPL version 3 section 7
+
+ If you modify this Program, or any covered work, by linking or
+ combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
+ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
+ modified version of those libraries), containing parts covered by the
+ terms of the respective license agreement, the licensors of this
+ Program grant you additional permission to convey the resulting work.
+*/
+
+#include "neural/xla/xla_runner.h"
+
+#include
+
+#include "utils/exception.h"
+#include "utils/logging.h"
+
+namespace lczero {
+namespace {
+
+size_t GetTypeSize(pblczero::XlaShapeProto::Type type) {
+ switch (type) {
+ case pblczero::XlaShapeProto::F32:
+ return sizeof(float);
+ case pblczero::XlaShapeProto::F64:
+ return sizeof(double);
+ case pblczero::XlaShapeProto::S32:
+ return sizeof(int32_t);
+ case pblczero::XlaShapeProto::S64:
+ return sizeof(int64_t);
+ default:
+ throw Exception("Add size for type " +
+ pblczero::XlaShapeProto::Type_Name(type));
+ }
+}
+
+std::string AsHexString(std::string_view buf) {
+ std::string result;
+ result.reserve(buf.size() * 2);
+ constexpr char hex[] = "0123456789abcdef";
+ for (unsigned char c : buf) {
+ result.push_back(hex[c >> 4]);
+ result.push_back(hex[c & 0xf]);
+ }
+ return result;
+}
+
+} // namespace
+
+std::string XlaTensor::DebugString() {
+ constexpr size_t kMaxSize = 1000;
+ constexpr size_t kSuffixSize = 200;
+ std::string result = "XlaTensor(";
+ result += "shape=[";
+ for (size_t i = 0; i < shape().size(); ++i) {
+ if (i > 0) result += ", ";
+ result += std::to_string(shape()[i]);
+ }
+ result += "], type=";
+ result += pblczero::XlaShapeProto::Type_Name(type());
+ result += ") size=" + std::to_string(size());
+ result += " data=";
+ if (size() <= kMaxSize) {
+ result += AsHexString({static_cast<const char*>(data()), size()});
+ } else {
+ result += AsHexString(
+ {static_cast<const char*>(data()), kMaxSize - kSuffixSize - 2});
+ result += "....";
+ result += AsHexString(
+ {static_cast<const char*>(data()) + size() - kSuffixSize, kSuffixSize});
+ }
+ return result;
+}
+
+XlaRunner::XlaRunner(const char* library_path, int device)
+ : pjrt_client_(Pjrt(library_path).CreateClient()), device_(device) {
+ CERR << "Devices:";
+ devices_ = pjrt_client_->GetDevices();
+ for (const auto& device : devices_) {
+ CERR << " " << device->ToString();
+ }
+ if (devices_.empty()) {
+ throw Exception("No devices available");
+ }
+}
+
+void XlaRunner::AddModule(size_t minibatch_size,
+ const pblczero::HloModuleProto& module) {
+ pblczero::CompileOptionsProto options;
+ options.mutable_executable_build_options()->set_num_replicas(1);
+ options.mutable_executable_build_options()->set_num_partitions(1);
+ options.mutable_executable_build_options()->set_device_ordinal(device_);
+ auto executable = pjrt_client_->CompileHlo(module.OutputAsString(),
+ options.OutputAsString());
+ executables_.push_back({minibatch_size, std::move(executable)});
+ std::sort(executables_.begin(), executables_.end());
+}
+
+void XlaRunner::SetFrozenInputs(
+ const std::vector<std::unique_ptr<XlaTensor>> inputs) {
+ param_idxs_.clear();
+ std::vector<std::unique_ptr<PjrtHostToDeviceTransfer>> transfers_;
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ const auto* input = inputs[i].get();
+ if (!input) {
+ param_idxs_.push_back(i);
+ continue;
+ }
+ transfers_.push_back(pjrt_client_->HostToDevice(
+ {static_cast<const char*>(input->data()), input->size()},
+ static_cast<PjrtType>(input->type()), input->shape(),
+ devices_.at(device_).get()));
+ }
+
+ owned_buffers_.clear();
+ buffers_.clear();
+ buffers_.resize(inputs.size());
+ size_t transfer_idx = 0;
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ if (inputs[i]) {
+ owned_buffers_.push_back(
+ transfers_[transfer_idx++]->AwaitAndReleaseBuffer());
+ buffers_[i] = owned_buffers_.back().get();
+ }
+ }
+}
+
+size_t XlaRunner::GetMaxBatchSize() const { return executables_.back().first; }
+
+std::vector<std::unique_ptr<XlaTensor>> XlaRunner::ExecuteBlocking(
+ const std::vector<XlaTensor*>& inputs) {
+ if (inputs.size() != 1) {
+ throw Exception("Only one input is kinda supported.");
+ }
+ // Find the smallest batch size that fits the input.
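+ // (E.g. with modules added for batch sizes 32 and 128, an input with batch
+ // dimension 50 is padded up and run on the 128-sized executable.)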
+ auto iter = std::find_if(
+ executables_.begin(), executables_.end(), [&](const auto& e) {
+ return e.first >= static_cast<size_t>(inputs[0]->shape()[0]);
+ });
+ if (iter == executables_.end()) {
+ throw Exception("No executable found for batch size " +
+ std::to_string(inputs[0]->shape()[0]));
+ }
+ const size_t batch_size = iter->first;
+ // Update the shape to match the rounded up batch size. After growing, the
+ // batch size must fit within tensor buffer capacity (it's fine to have
+ // garbage in the tail of that buffer).
+ std::vector<int64_t> new_shape = inputs[0]->shape();
+ new_shape[0] = batch_size;
+ const size_t input_size = std::accumulate(new_shape.begin(), new_shape.end(),
+ 1, std::multiplies<int64_t>()) *
+ GetTypeSize(inputs[0]->type());
+ if (input_size > inputs[0]->capacity()) {
+ throw Exception("Input buffer too small");
+ }
+ // Transfer the input to the device.
+ auto input_buffer =
+ pjrt_client_
+ ->HostToDevice(
+ {static_cast<const char*>(inputs[0]->data()), input_size},
+ static_cast<PjrtType>(inputs[0]->type()), new_shape,
+ devices_[0].get())
+ ->AwaitAndReleaseBuffer();
+ // Make a copy to support multiple concurrent calls, not sure if it's needed.
+ auto input_buffers = buffers_;
+ input_buffers[param_idxs_[0]] = input_buffer.get();
+ // Execute!
+ auto outputs = iter->second->ExecuteBlocking(input_buffers);
+
+ // Now we need to transfer the outputs back to the host.
+ std::vector<std::unique_ptr<XlaTensor>> result;
+ result.reserve(outputs.size());
+ std::vector<std::string> output_buffers;
+ std::vector<std::unique_ptr<PjrtEvent>> done_events;
+ output_buffers.reserve(outputs.size());
+ done_events.reserve(outputs.size());
+ // Initiate transfers from device to host.
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ const auto& output = outputs[i];
+ output_buffers.emplace_back();
+ auto& buffer = output_buffers.back();
+ buffer.resize(output->GetSize());
+ done_events.push_back(output->DeviceToHost(&buffer[0], buffer.size()));
+ }
+ // Wait for the transfers to complete.
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ const auto& output = outputs[i];
+ done_events[i]->Await();
+ result.push_back(std::make_unique<XlaTensorOwned>(
+ output->GetDimensions(),
+ static_cast<pblczero::XlaShapeProto::Type>(output->GetType()),
+ std::move(output_buffers[i])));
+ }
+ return result;
+}
+
+} // namespace lczero
\ No newline at end of file
diff --git a/src/neural/xla/xla_runner.h b/src/neural/xla/xla_runner.h
new file mode 100644
index 0000000000..85352c4142
--- /dev/null
+++ b/src/neural/xla/xla_runner.h
@@ -0,0 +1,134 @@
+/*
+ This file is part of Leela Chess Zero.
+ Copyright (C) 2024 The LCZero Authors
+
+ Leela Chess is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Chess is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Chess. If not, see .
+
+ Additional permission under GNU GPL version 3 section 7
+
+ If you modify this Program, or any covered work, by linking or
+ combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
+ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
+ modified version of those libraries), containing parts covered by the
+ terms of the respective license agreement, the licensors of this
+ Program grant you additional permission to convey the resulting work.
+*/
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include "neural/xla/hlo.pb.h"
+#include "neural/xla/pjrt.h"
+
+namespace lczero {
+
+// An interface for an in-host-memory tensor in XLA format.
+class XlaTensor {
+ public:
+ virtual ~XlaTensor() = default;
+ virtual const std::vector<int64_t>& shape() const = 0;
+ virtual const void* data() const = 0;
+ // Returns the amount of valid data, in bytes.
+ virtual size_t size() const = 0;
+ // Returns the amount of memory that may be addressed, in bytes. This is
+ // useful when the size of the buffer has to be increased to match the
+ // supported batch size.
+ virtual size_t capacity() const = 0;
+ virtual pblczero::XlaShapeProto::Type type() const = 0;
+
+ std::string DebugString();
+};
+
+// Not-owned XLA tensor, used e.g. when ONNX buffer can be used directly, to
+// avoid unnecessary copy.
+class XlaTensorNotOwned : public XlaTensor {
+ public:
+ XlaTensorNotOwned(const std::vector<int64_t>& shape, std::string_view data,
+ pblczero::XlaShapeProto::Type type)
+ : shape_(&shape), data_(data), type_(type) {}
+
+ const std::vector<int64_t>& shape() const override { return *shape_; }
+ const void* data() const override { return data_.data(); }
+ size_t size() const override { return data_.size(); }
+ size_t capacity() const override { return data_.size(); }
+ pblczero::XlaShapeProto::Type type() const override { return type_; }
+
+ private:
+ const std::vector<int64_t>* shape_;
+ std::string_view data_;
+ pblczero::XlaShapeProto::Type type_;
+};
+
+// Tensor that owns data, used e.g. for XLA output.
+class XlaTensorOwned : public XlaTensor {
+ public:
+ XlaTensorOwned(const std::vector<int64_t>& shape,
+ pblczero::XlaShapeProto::Type type, std::string data)
+ : shape_(shape), type_(type), data_(data) {}
+
+ const std::vector<int64_t>& shape() const override { return shape_; }
+ const void* data() const override { return data_.data(); }
+ size_t size() const override { return data_.size(); }
+ size_t capacity() const override { return data_.size(); }
+ pblczero::XlaShapeProto::Type type() const override { return type_; }
+
+ private:
+ std::vector<int64_t> shape_;
+ pblczero::XlaShapeProto::Type type_;
+ std::string data_;
+};
+
+// A class that keeps several XLA executables (for different batch sizes),
+// manages common buffers among them, and chooses the right executable for a
+// batch size.
+class XlaRunner {
+ public:
+ // library_path is the path to the PJRT library; device is the device index.
+ XlaRunner(const char* library_path, int device);
+ // Compiles and adds a module for the given batch size.
+ void AddModule(size_t minibatch_size, const pblczero::HloModuleProto& module);
+ // Transfers inputs to the device and executes the executable corresponding
+ // to the batch size. Only non-frozen inputs are passed as arguments.
+ // Currently only a single input is supported (just because we don't need
+ // more).
+ std::vector<std::unique_ptr<XlaTensor>> ExecuteBlocking(
+ const std::vector<XlaTensor*>& inputs);
+ // Inputs that are shared between all calls (i.e. network weights passed as
+ // parameters). These inputs are transferred to the device immediately (and
+ // not for each inference).
+ void SetFrozenInputs(const std::vector<std::unique_ptr<XlaTensor>> inputs);
+ // Maximum supported batch size. The capacity (not size) of the input tensors
+ // is expected to be able to fit this size.
+ size_t GetMaxBatchSize() const;
+
+ private:
+ std::unique_ptr<PjrtClient> pjrt_client_;
+ std::vector<std::unique_ptr<PjrtDevice>> devices_;
+ // Compiled executables per batch size.
+ std::vector<std::pair<size_t, std::unique_ptr<PjrtExecutable>>> executables_;
+ // Frozen inputs, in no particular order, kept for ownership.
+ std::vector<std::unique_ptr<PjrtDeviceBuffer>> owned_buffers_;
+ // Vector of pointers to all input buffers, that is passed to PJRT. Frozen
+ // parameters (constants) are pre-filled in SetFrozenInputs(), and non-frozen
+ // inputs (input planes) are created and filled in every request.
+ std::vector<PjrtDeviceBuffer*> buffers_;
+ std::vector<size_t> param_idxs_;
+ int device_;
+};
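+
+// A minimal usage sketch (the plugin path and batch size are illustrative):
+//
+//   XlaRunner runner("./pjrt_c_api_gpu_plugin.so", /*device=*/0);
+//   runner.AddModule(/*minibatch_size=*/128, converted.hlo_module);
+//   runner.SetFrozenInputs(std::move(frozen));  // network weights as tensors
+//   auto outputs = runner.ExecuteBlocking({&input_planes});
+//
+// where `converted` is an Onnx2HloResult and `frozen` / `input_planes` are
+// assumed to be prepared by the calling backend code.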
+
+} // namespace lczero
\ No newline at end of file
diff --git a/third_party/pjrt_c_api.h b/third_party/pjrt_c_api.h
new file mode 100644
index 0000000000..f1ab16f1a6
--- /dev/null
+++ b/third_party/pjrt_c_api.h
@@ -0,0 +1,2175 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_PJRT_C_PJRT_C_API_H_
+#define XLA_PJRT_C_PJRT_C_API_H_
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#define PJRT_STRUCT_SIZE(struct_type, last_field) \
+ offsetof(struct_type, last_field) + sizeof(((struct_type*)0)->last_field)
+
+#define PJRT_DEFINE_STRUCT_TRAITS(sname, last_field) \
+ typedef struct sname sname; \
+ enum { sname##_STRUCT_SIZE = PJRT_STRUCT_SIZE(sname, last_field) }
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// --------------------------------- Version -----------------------------------
+
+// Incremented when an ABI-incompatible change is made to the interface.
+// Changes include:
+// * Deleting a method or argument
+// * Changing the type of an argument
+// * Rearranging fields in the PJRT_Api or argument structs
+#define PJRT_API_MAJOR 0
+
+// Incremented when the interface is updated in a way that is potentially
+// ABI-compatible with older versions, if supported by the caller and/or
+// implementation.
+//
+// Callers can implement forwards compatibility by using PJRT_Api_Version to
+// check if the implementation is aware of newer interface additions.
+//
+// Implementations can implement backwards compatibility by using the
+// `struct_size` fields to detect how many struct fields the caller is aware of.
+//
+// Changes include:
+// * Adding a new field to the PJRT_Api or argument structs
+// * Renaming a method or argument (doesn't affect ABI)
+#define PJRT_API_MINOR 40
+
+// The plugin should set the major_version and minor_version of
+// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
+// this header that the implementation was compiled with.
+struct PJRT_Api_Version {
+ size_t struct_size;
+ void* priv;
+ int major_version; // out
+ int minor_version; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Api_Version, minor_version);
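+
+// Caller-side version check (sketch; assumes `GetPjrtApi()` is the symbol
+// exported by the PJRT plugin, which is the usual convention):
+//
+//   const PJRT_Api* api = GetPjrtApi();
+//   if (api->pjrt_api_version.major_version != PJRT_API_MAJOR) {
+//     // ABI-incompatible plugin; refuse to use it.
+//   }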
+
+// ---------------------------------- Errors -----------------------------------
+
+// PJRT C API methods generally return a PJRT_Error*, which is nullptr if there
+// is no error and set if there is. The implementation allocates any returned
+// PJRT_Errors, but the caller is always responsible for freeing them via
+// PJRT_Error_Destroy.
+
+typedef struct PJRT_Error PJRT_Error;
+
+struct PJRT_Error_Destroy_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Error* error;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Error_Destroy_Args, error);
+
+// Frees `error`. `error` can be nullptr.
+typedef void PJRT_Error_Destroy(PJRT_Error_Destroy_Args* args);
+
+struct PJRT_Error_Message_Args {
+ size_t struct_size;
+ void* priv;
+ const PJRT_Error* error;
+ // Has the lifetime of `error`.
+ const char* message; // out
+ size_t message_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Error_Message_Args, message_size);
+
+// Gets the human-readable reason for `error`. `message` has the lifetime of
+// `error`.
+typedef void PJRT_Error_Message(PJRT_Error_Message_Args* args);
+
+// Codes are based on https://abseil.io/docs/cpp/guides/status-codes
+typedef enum {
+ PJRT_Error_Code_CANCELLED = 1,
+ PJRT_Error_Code_UNKNOWN = 2,
+ PJRT_Error_Code_INVALID_ARGUMENT = 3,
+ PJRT_Error_Code_DEADLINE_EXCEEDED = 4,
+ PJRT_Error_Code_NOT_FOUND = 5,
+ PJRT_Error_Code_ALREADY_EXISTS = 6,
+ PJRT_Error_Code_PERMISSION_DENIED = 7,
+ PJRT_Error_Code_RESOURCE_EXHAUSTED = 8,
+ PJRT_Error_Code_FAILED_PRECONDITION = 9,
+ PJRT_Error_Code_ABORTED = 10,
+ PJRT_Error_Code_OUT_OF_RANGE = 11,
+ PJRT_Error_Code_UNIMPLEMENTED = 12,
+ PJRT_Error_Code_INTERNAL = 13,
+ PJRT_Error_Code_UNAVAILABLE = 14,
+ PJRT_Error_Code_DATA_LOSS = 15,
+ PJRT_Error_Code_UNAUTHENTICATED = 16
+} PJRT_Error_Code;
+
+struct PJRT_Error_GetCode_Args {
+ size_t struct_size;
+ void* priv;
+ const PJRT_Error* error;
+ PJRT_Error_Code code; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Error_GetCode_Args, code);
+
+typedef PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args);
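+
+// Typical caller-side error handling (sketch; `api` is assumed to be the
+// `PJRT_Api*` obtained from the plugin, and the same pattern applies to every
+// method below):
+//
+//   PJRT_Error* error = api->PJRT_Client_Create(&create_args);
+//   if (error != nullptr) {
+//     PJRT_Error_Message_Args msg_args = {0};
+//     msg_args.struct_size = PJRT_Error_Message_Args_STRUCT_SIZE;
+//     msg_args.error = error;
+//     api->PJRT_Error_Message(&msg_args);
+//     // msg_args.message / msg_args.message_size describe the failure.
+//     PJRT_Error_Destroy_Args destroy_args = {0};
+//     destroy_args.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE;
+//     destroy_args.error = error;
+//     api->PJRT_Error_Destroy(&destroy_args);
+//   }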
+
+// Function for PJRT implementation to pass to callback functions provided by
+// caller so the callback can create a PJRT_Error* on error (to return to the
+// implementation). `message` is only required to live for the
+// PJRT_CallbackError call, i.e. the PJRT_CallbackError implementation must copy
+// `message` into the PJRT_Error.
+typedef PJRT_Error* (*PJRT_CallbackError)(PJRT_Error_Code code,
+ const char* message,
+ size_t message_size);
+
+// ---------------------------- Named Values -----------------------------------
+
+typedef enum {
+ PJRT_NamedValue_kString = 0,
+ PJRT_NamedValue_kInt64,
+ PJRT_NamedValue_kInt64List,
+ PJRT_NamedValue_kFloat,
+ PJRT_NamedValue_kBool,
+} PJRT_NamedValue_Type;
+
+// Named value for key-value pairs.
+struct PJRT_NamedValue {
+ size_t struct_size;
+ void* priv;
+ const char* name;
+ size_t name_size;
+ PJRT_NamedValue_Type type;
+ union {
+ const char* string_value;
+ int64_t int64_value;
+ const int64_t* int64_array_value;
+ float float_value;
+ bool bool_value;
+ };
+ // `value_size` is the number of elements for array/string and 1 for scalar
+ // values.
+ size_t value_size;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_NamedValue, value_size);
+
+// ---------------------------------- Plugin -----------------------------------
+
+struct PJRT_Plugin_Initialize_Args {
+ size_t struct_size;
+ void* priv;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Initialize_Args, priv);
+
+// One-time plugin setup. Must be called before any other functions are called.
+typedef PJRT_Error* PJRT_Plugin_Initialize(PJRT_Plugin_Initialize_Args* args);
+
+struct PJRT_Plugin_Attributes_Args {
+ size_t struct_size;
+ void* priv;
+ // Returned attributes have the lifetime of the process.
+ const PJRT_NamedValue* attributes; // out
+ size_t num_attributes; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Attributes_Args, attributes);
+
+// Returns an array of plugin attributes which are key-value pairs. One example
+// attribute is the minimum supported StableHLO version.
+// TODO(b/280349977): standardize the list of attributes.
+typedef PJRT_Error* PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args* args);
+
+// ---------------------------------- Events -----------------------------------
+
+// Represents a notifying event that is returned by PJRT APIs that enqueue
+// asynchronous work, informing callers when the work is complete and reporting
+// a value of type `PJRT_Error*` or `nullptr` as error status.
+//
+// Callers are always responsible for freeing `PJRT_Event`s by calling
+// `PJRT_Event_Destroy`.
+typedef struct PJRT_Event PJRT_Event;
+
+struct PJRT_Event_Destroy_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Event* event;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_Destroy_Args, event);
+
+// Frees `event`. `event` can be `nullptr`.
+typedef PJRT_Error* PJRT_Event_Destroy(PJRT_Event_Destroy_Args* args);
+
+struct PJRT_Event_IsReady_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Event* event;
+ bool is_ready; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_IsReady_Args, is_ready);
+
+// Returns true if this PJRT_Event has completed, including if an error has
+// occurred.
+typedef PJRT_Error* PJRT_Event_IsReady(PJRT_Event_IsReady_Args* args);
+
+struct PJRT_Event_Error_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Event* event;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_Error_Args, event);
+
+// Should only be called if PJRT_Event_IsReady returns true.
+// Returns `nullptr` if there is no error.
+// The returned error should be freed with `PJRT_Error_Destroy`.
+//
+// If `PJRT_Event_Await` has been called, this will return a pointer to an
+// identical error status as that call, as will subsequent calls to
+// `PJRT_Event_Error`. However, each of these `PJRT_Error *` pointers are
+// independent of `PJRT_Error *`s returned by other function calls, so they must
+// each be freed separately using `PJRT_Error_Destroy`.
+typedef PJRT_Error* PJRT_Event_Error(PJRT_Event_Error_Args* args);
+
+struct PJRT_Event_Await_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Event* event;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_Await_Args, event);
+
+// Blocks the calling thread until `event` is ready, then returns the error
+// status (with `nullptr` indicating no error). The returned status should be
+// freed with `PJRT_Error_Destroy`.
+typedef PJRT_Error* PJRT_Event_Await(PJRT_Event_Await_Args* args);
+
+// A callback to be performed once an event is ready. It will be called on the
+// event's error state and a pointer to an object of the caller's choice.
+// Ownership of `error` is passed to the callback. The callback must destroy
+// `error` via `PJRT_Error_Destroy`. The caller retains ownership of `user_arg`.
+typedef void (*PJRT_Event_OnReadyCallback)(PJRT_Error* error, void* user_arg);
+
+struct PJRT_Event_OnReady_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Event* event;
+ PJRT_Event_OnReadyCallback callback;
+ // `user_arg` allows `callback` to be called with arbitrary arguments (e.g.
+ // via pointers in a struct cast to void*).
+ void* user_arg;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_OnReady_Args, user_arg);
+
+// Registers `callback` to be called once `event` is ready, with `event`'s
+// error status and a pointer to an object of the caller's choice as arguments.
+typedef PJRT_Error* PJRT_Event_OnReady(PJRT_Event_OnReady_Args* args);
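+
+// Registration sketch (`on_done` and `my_state` are placeholder caller code;
+// `api` is the loaded `PJRT_Api*`):
+//
+//   void on_done(PJRT_Error* error, void* user_arg) {
+//     // ... record the status, then free `error` via PJRT_Error_Destroy ...
+//   }
+//
+//   PJRT_Event_OnReady_Args ready_args = {0};
+//   ready_args.struct_size = PJRT_Event_OnReady_Args_STRUCT_SIZE;
+//   ready_args.event = event;
+//   ready_args.callback = &on_done;
+//   ready_args.user_arg = &my_state;
+//   PJRT_Error* error = api->PJRT_Event_OnReady(&ready_args);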
+
+// ---------------------------------- Client -----------------------------------
+
+typedef struct PJRT_Client PJRT_Client;
+typedef struct PJRT_Device PJRT_Device;
+typedef struct PJRT_Memory PJRT_Memory;
+typedef struct PJRT_DeviceDescription PJRT_DeviceDescription;
+typedef struct PJRT_TopologyDescription PJRT_TopologyDescription;
+typedef struct PJRT_Executable PJRT_Executable;
+typedef struct PJRT_LoadedExecutable PJRT_LoadedExecutable;
+typedef struct PJRT_Buffer PJRT_Buffer;
+
+// The caller of PJRT_Client_Create can optionally provide a key-value store
+// accessible across nodes and/or processes. KV store access may be necessary to
+// create some multi-node/multi-process clients. The caller can provide the two
+// callbacks below to access the key-value store.
+
+// A callback to delete the value returned by PJRT_KeyValueGetCallback.
+typedef void (*PJRT_KeyValueGetCallback_ValueDeleter)(char* value);
+
+struct PJRT_KeyValueGetCallback_Args {
+ size_t struct_size;
+ void* priv;
+ const char* key;
+ size_t key_size;
+ int timeout_in_ms;
+ PJRT_CallbackError* callback_error;
+ void* user_arg;
+ char* value; // out
+ size_t value_size; // out
+ // The caller needs to set a PJRT_KeyValueGetCallback_ValueDeleter to delete
+ // the value returned by PJRT_KeyValueGetCallback. The implementation is
+ // responsible for copying `value` and then calling value_deleter_callback.
+ PJRT_KeyValueGetCallback_ValueDeleter value_deleter_callback; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueGetCallback_Args,
+ value_deleter_callback);
+
+// Requirements for PJRT_KeyValueGetCallback implementation: (1) Thread-safe.
+// (2) The caller that provides the two callbacks is responsible for avoiding
+// key collisions between different users of key-value store (i.e. between
+// different plugins, but not between different nodes in one plugin). (3)
+// Blocking.
+typedef PJRT_Error* (*PJRT_KeyValueGetCallback)(
+ PJRT_KeyValueGetCallback_Args* args);
+
+struct PJRT_KeyValuePutCallback_Args {
+ size_t struct_size;
+ void* priv;
+ const char* key;
+ size_t key_size;
+ // Only needs to stay alive for the duration of the PJRT_KeyValuePutCallback
+ // call.
+ const char* value;
+ size_t value_size;
+ PJRT_CallbackError* callback_error;
+ void* user_arg;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValuePutCallback_Args, user_arg);
+
+// Requirements for PJRT_KeyValuePutCallback implementation: (1) Thread-safe.
+// (2) The caller that provides the two callbacks is responsible for avoiding
+// key collisions between different users of key-value store (i.e. between
+// different plugins, but not between different nodes in one plugin).
+typedef PJRT_Error* (*PJRT_KeyValuePutCallback)(
+ PJRT_KeyValuePutCallback_Args* args);
+
+struct PJRT_Client_Create_Args {
+ size_t struct_size;
+ void* priv;
+ // Extra platform-specific options to create a client.
+ const PJRT_NamedValue* create_options;
+ size_t num_options;
+ // Key-value get/put callback provided by the caller of PJRT_Client_Create.
+ // PJRT client can use these callbacks to share information between
+ // processes/nodes.
+ PJRT_KeyValueGetCallback kv_get_callback;
+ // Will be passed to `kv_get_callback` as `user_arg` argument.
+ void* kv_get_user_arg;
+ PJRT_KeyValuePutCallback kv_put_callback;
+ // Will be passed to `kv_put_callback` as `user_arg` argument.
+ void* kv_put_user_arg;
+
+ PJRT_Client* client; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, client);
+
+// Creates and initializes a new PJRT_Client and returns in `client`.
+typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args);
+
+struct PJRT_Client_Destroy_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Destroy_Args, client);
+
+// Shuts down and frees `client`. `client` can be nullptr.
+typedef PJRT_Error* PJRT_Client_Destroy(PJRT_Client_Destroy_Args* args);
+
+struct PJRT_Client_PlatformName_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ // `platform_name` has the same lifetime as `client`. It is owned by `client`.
+ const char* platform_name; // out
+ size_t platform_name_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_PlatformName_Args, platform_name_size);
+
+// Returns a string that identifies the platform (e.g. "cpu", "gpu", "tpu").
+typedef PJRT_Error* PJRT_Client_PlatformName(
+ PJRT_Client_PlatformName_Args* args);
+
+struct PJRT_Client_ProcessIndex_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ int process_index; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_ProcessIndex_Args, process_index);
+
+// Return the process index of this client. Always 0 in single-process
+// settings.
+typedef PJRT_Error* PJRT_Client_ProcessIndex(
+ PJRT_Client_ProcessIndex_Args* args);
+
+struct PJRT_Client_PlatformVersion_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ // `platform_version` has the same lifetime as `client`. It's owned by
+ // `client`.
+ const char* platform_version; // out
+ size_t platform_version_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_PlatformVersion_Args,
+ platform_version_size);
+
+// Returns a string containing human-readable, platform-specific version info
+// (e.g. the CUDA version on GPU or libtpu version on Cloud TPU).
+typedef PJRT_Error* PJRT_Client_PlatformVersion(
+ PJRT_Client_PlatformVersion_Args* args);
+
+struct PJRT_Client_TopologyDescription_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ // Is owned by and has the same lifetime as `client`.
+ PJRT_TopologyDescription* topology; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_TopologyDescription_Args, topology);
+
+// Returns the topology description of the runtime topology. The returned
+// topology is owned by the client and should not be deleted by the caller.
+typedef PJRT_Error* PJRT_Client_TopologyDescription(
+ PJRT_Client_TopologyDescription_Args* args);
+
+struct PJRT_Client_Devices_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ PJRT_Device* const* devices; // out
+ size_t num_devices; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Devices_Args, num_devices);
+
+// Returns a list of all devices visible to the runtime, including addressable
+// and non-addressable devices.
+typedef PJRT_Error* PJRT_Client_Devices(PJRT_Client_Devices_Args* args);
+
+struct PJRT_Client_AddressableDevices_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ PJRT_Device* const* addressable_devices; // out
+ size_t num_addressable_devices; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_AddressableDevices_Args,
+ num_addressable_devices);
+
+// Returns a list of devices that are addressable from the client.
+// Addressable devices are those that the client can issue commands to.
+// All devices are addressable in a single-process environment.
+typedef PJRT_Error* PJRT_Client_AddressableDevices(
+ PJRT_Client_AddressableDevices_Args* args);
+
+struct PJRT_Client_LookupDevice_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ int id;
+ // `device` has the same lifetime as `client`. It is owned by `client`.
+ PJRT_Device* device; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_LookupDevice_Args, device);
+
+// Returns a PJRT_Device* with the specified ID as returned by
+// PJRT_DeviceDescription_Id.
+typedef PJRT_Error* PJRT_Client_LookupDevice(
+ PJRT_Client_LookupDevice_Args* args);
+
+struct PJRT_Client_LookupAddressableDevice_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ int local_hardware_id;
+ // `addressable_device` has the same lifetime as `client`. It is owned by
+ // `client`.
+ PJRT_Device* addressable_device; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_LookupAddressableDevice_Args,
+ addressable_device);
+
+// Returns an addressable PJRT_Device* with the specified ID as returned by
+// PJRT_DeviceDescription_LocalHardwareId.
+typedef PJRT_Error* PJRT_Client_LookupAddressableDevice(
+ PJRT_Client_LookupAddressableDevice_Args* args);
+
+struct PJRT_Client_AddressableMemories_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ PJRT_Memory* const* addressable_memories; // out
+ size_t num_addressable_memories; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_AddressableMemories_Args,
+ num_addressable_memories);
+
+// Returns a list of memories that are addressable from the client. Addressable
+// memories are those that the client can directly transfer data to and from.
+// All memories are addressable in a single-process environment.
+typedef PJRT_Error* PJRT_Client_AddressableMemories(
+ PJRT_Client_AddressableMemories_Args* args);
+
+struct PJRT_Program {
+ size_t struct_size;
+ void* priv;
+ // Serialized code in the specified format below.
+ // String is owned by the caller.
+ char* code; // in/out depending on usage
+ size_t code_size;
+ // Supported formats are:
+ // "hlo": code string takes serialized HloModuleProto.
+ // "hlo_with_config": code string takes serialized HloModuleProtoWithConfig.
+ // "mlir": code string takes MLIR module bytecode (or string).
+ // Ownership of `format` varies across API functions.
+ const char* format;
+ size_t format_size;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Program, format_size);
+
+struct PJRT_Client_Compile_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ // Only needs to stay alive for the duration of the Compile call.
+ // `program->format` and `program->format_size` are owned by the caller.
+ const PJRT_Program* program;
+ // TODO(b/240560013): consider putting some of option fields in priv.
+ // Serialized CompileOptionsProto
+ // (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/compile_options.proto)
+ const char* compile_options;
+ size_t compile_options_size;
+ PJRT_LoadedExecutable* executable; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Compile_Args, executable);
+
+// Compiles a program in specified format (such as MLIR or HLO) with given
+// `options`.
+typedef PJRT_Error* PJRT_Client_Compile(PJRT_Client_Compile_Args* args);
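+
+// Compilation sketch ("mlir" chosen for illustration; `mlir_bytes`,
+// `options_bytes` and `client` are placeholders owned by the caller):
+//
+//   PJRT_Program program = {0};
+//   program.struct_size = PJRT_Program_STRUCT_SIZE;
+//   program.code = mlir_bytes;
+//   program.code_size = mlir_bytes_size;
+//   program.format = "mlir";
+//   program.format_size = 4;
+//
+//   PJRT_Client_Compile_Args compile_args = {0};
+//   compile_args.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
+//   compile_args.client = client;
+//   compile_args.program = &program;
+//   compile_args.compile_options = options_bytes;       // CompileOptionsProto
+//   compile_args.compile_options_size = options_bytes_size;
+//   PJRT_Error* error = api->PJRT_Client_Compile(&compile_args);
+//   // On success, compile_args.executable is the new PJRT_LoadedExecutable*.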
+
+struct PJRT_Client_DefaultDeviceAssignment_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ int num_replicas;
+ int num_partitions;
+ // Must be greater than or equal to `num_replicas * num_partitions`
+ size_t default_assignment_size;
+ // Points to an array of size `default_assignment_size`.
+ // This API writes `num_replicas * num_partitions` ints within that buffer.
+ // The caller retains ownership of this memory.
+ int* default_assignment; // pointer to array in; values written as out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_DefaultDeviceAssignment_Args,
+ default_assignment);
+
+typedef PJRT_Error* PJRT_Client_DefaultDeviceAssignment(
+ PJRT_Client_DefaultDeviceAssignment_Args* args);
+
+typedef enum {
+ // Invalid primitive type to serve as default.
+ PJRT_Buffer_Type_INVALID,
+
+ // Predicates are two-state booleans.
+ PJRT_Buffer_Type_PRED,
+
+ // Signed integral values of fixed width.
+ PJRT_Buffer_Type_S8,
+ PJRT_Buffer_Type_S16,
+ PJRT_Buffer_Type_S32,
+ PJRT_Buffer_Type_S64,
+
+ // Unsigned integral values of fixed width.
+ PJRT_Buffer_Type_U8,
+ PJRT_Buffer_Type_U16,
+ PJRT_Buffer_Type_U32,
+ PJRT_Buffer_Type_U64,
+
+ // Floating-point values of fixed width.
+ PJRT_Buffer_Type_F16,
+ PJRT_Buffer_Type_F32,
+ PJRT_Buffer_Type_F64,
+
+ // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
+ // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
+ // and 7 bits for the mantissa.
+ PJRT_Buffer_Type_BF16,
+
+ // Complex values of fixed width.
+ //
+ // Paired F32 (real, imag), as in std::complex<float>.
+ PJRT_Buffer_Type_C64,
+ // Paired F64 (real, imag), as in std::complex<double>.
+ PJRT_Buffer_Type_C128,
+
+ // Truncated 8 bit floating-point formats.
+ PJRT_Buffer_Type_F8E5M2,
+ PJRT_Buffer_Type_F8E4M3FN,
+ PJRT_Buffer_Type_F8E4M3B11FNUZ,
+ PJRT_Buffer_Type_F8E5M2FNUZ,
+ PJRT_Buffer_Type_F8E4M3FNUZ,
+
+ // 4-bit integer types
+ PJRT_Buffer_Type_S4,
+ PJRT_Buffer_Type_U4,
+} PJRT_Buffer_Type;
+
+typedef enum {
+ // The runtime may not hold references to `data` after the call to
+ // `PJRT_Client_BufferFromHostBuffer` completes. The caller promises that
+ // `data` is immutable and will not be freed only for the duration of the
+ // PJRT_Client_BufferFromHostBuffer call.
+ PJRT_HostBufferSemantics_kImmutableOnlyDuringCall,
+
+ // The runtime may hold onto `data` after the call to
+ // `PJRT_Client_BufferFromHostBuffer`
+ // returns while the runtime completes a transfer to the device. The caller
+ // promises not to mutate or free `data` until the transfer completes, at
+ // which point `done_with_host_buffer` will be triggered.
+ PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes,
+
+ // The PjRtBuffer may alias `data` internally and the runtime may use the
+ // `data` contents as long as the buffer is alive. The caller promises to
+ // keep `data` alive and not to mutate its contents as long as the buffer is
+ // alive; to notify the caller that the buffer may be freed, the runtime
+ // will call `done_with_host_buffer` when the PjRtBuffer is freed.
+ PJRT_HostBufferSemantics_kZeroCopy,
+} PJRT_HostBufferSemantics;
+
+typedef enum {
+ PJRT_Buffer_MemoryLayout_Type_Tiled = 0,
+ PJRT_Buffer_MemoryLayout_Type_Strides,
+} PJRT_Buffer_MemoryLayout_Type;
+
+struct PJRT_Buffer_MemoryLayout_Tiled {
+ size_t struct_size;
+ void* priv;
+ // A map from physical dimension numbers to logical dimension numbers.
+ // The first element is the most minor physical dimension (fastest varying
+ // index) and the last the most major (slowest varying index). The contents of
+ // the vector are the indices of the *logical* dimensions in the shape. Must
+ // be the same size as the number of dimensions of the buffer.
+ const int64_t* minor_to_major;
+ size_t minor_to_major_size;
+ // A concatenated list of tile dimensions.
+ const int64_t* tile_dims;
+ // The list of tile dimension sizes. The size of this list is `num_tiles`.
+ const size_t* tile_dim_sizes;
+ size_t num_tiles;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_MemoryLayout_Tiled, num_tiles);
+
+struct PJRT_Buffer_MemoryLayout_Strides {
+ size_t struct_size;
+ void* priv;
+ // Number of bytes to traverse per dimension. Must be the same size as
+ // the number of dimensions of the data. Caution: `byte_strides` are allowed
+ // to be negative, in which case data may need to point to the interior of
+ // the buffer, not necessarily its start.
+ const int64_t* byte_strides;
+ size_t num_byte_strides;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_MemoryLayout_Strides, num_byte_strides);
+
+// Describe the memory layout. It can be (1) a list of minor-to-major order and
+// optional tilings (each tile is a list of dimensions), or (2) a list of
+// strides.
+struct PJRT_Buffer_MemoryLayout {
+ size_t struct_size;
+ void* priv;
+ union {
+ PJRT_Buffer_MemoryLayout_Tiled tiled;
+ PJRT_Buffer_MemoryLayout_Strides strides;
+ };
+ PJRT_Buffer_MemoryLayout_Type type;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_MemoryLayout, type);
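+
+// Layout sketch: a dense, row-major F32 array of shape [batch, 112, 8, 8]
+// (e.g. the lc0 input planes) described via byte strides; values illustrative:
+//
+//   const int64_t byte_strides[] = {112 * 8 * 8 * 4, 8 * 8 * 4, 8 * 4, 4};
+//   PJRT_Buffer_MemoryLayout layout = {0};
+//   layout.struct_size = PJRT_Buffer_MemoryLayout_STRUCT_SIZE;
+//   layout.type = PJRT_Buffer_MemoryLayout_Type_Strides;
+//   layout.strides.struct_size = PJRT_Buffer_MemoryLayout_Strides_STRUCT_SIZE;
+//   layout.strides.byte_strides = byte_strides;
+//   layout.strides.num_byte_strides = 4;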
+
+struct PJRT_Client_BufferFromHostBuffer_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ // Pointer to the host buffer
+ const void* data;
+ // The type of the `data`, and the type of the resulting output `buffer`
+ PJRT_Buffer_Type type;
+ // The array dimensions of `data`.
+ const int64_t* dims;
+ size_t num_dims;
+
+ // Number of bytes to traverse per dimension of the input data. Must be the
+ // same size as `dims`, or empty. If empty, the array is assumed to have a
+ // dense layout with dimensions in major-to-minor order
+ // Caution: `byte_strides` are allowed to be negative, in which case `data`
+ // may need to point to the interior of the buffer, not necessarily its start.
+ const int64_t* byte_strides;
+ size_t num_byte_strides;
+
+ PJRT_HostBufferSemantics host_buffer_semantics;
+
+ // Device to copy host data to.
+ PJRT_Device* device;
+
+ // If nullptr, host data will be copied to `device`, otherwise we copy data to
+ // `memory`.
+ PJRT_Memory* memory;
+
+ // The caller is responsible to keep the data (tiled or strides) in the
+ // device_layout alive during the call. If nullptr, the device layout is
+ // assumed to be a dense layout with dimensions in major-to-minor order.
+ PJRT_Buffer_MemoryLayout* device_layout;
+
+ // Event indicating when it's safe to free `data`. The caller is responsible
+ // for calling PJRT_Event_Destroy.
+ PJRT_Event* done_with_host_buffer; // out
+
+ // Output device buffer. The caller is responsible for calling
+ // PJRT_Buffer_Destroy.
+ PJRT_Buffer* buffer; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_BufferFromHostBuffer_Args, buffer);
+
+// Asynchronously copies a buffer stored on host to device memory.
+typedef PJRT_Error* PJRT_Client_BufferFromHostBuffer(
+ PJRT_Client_BufferFromHostBuffer_Args* args);
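+
+// Host-to-device sketch (a 1-D F32 array of `count` elements at `host_data`;
+// `client`, `device` and `api` are assumed to have been obtained already):
+//
+//   const int64_t dims[] = {count};
+//   PJRT_Client_BufferFromHostBuffer_Args h2d_args = {0};
+//   h2d_args.struct_size = PJRT_Client_BufferFromHostBuffer_Args_STRUCT_SIZE;
+//   h2d_args.client = client;
+//   h2d_args.data = host_data;
+//   h2d_args.type = PJRT_Buffer_Type_F32;
+//   h2d_args.dims = dims;
+//   h2d_args.num_dims = 1;
+//   h2d_args.byte_strides = nullptr;      // dense, major-to-minor layout
+//   h2d_args.host_buffer_semantics =
+//       PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes;
+//   h2d_args.device = device;
+//   PJRT_Error* error = api->PJRT_Client_BufferFromHostBuffer(&h2d_args);
+//   // Await h2d_args.done_with_host_buffer before freeing `host_data`;
+//   // h2d_args.buffer is the resulting device buffer.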
+
+struct PJRT_Client_CreateViewOfDeviceBuffer_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ // A pointer to a non-owned device buffer. A PJRT_Buffer that is a non-owned
+ // view of this device buffer will be created.
+ void* device_buffer_ptr;
+ const int64_t* dims;
+ size_t num_dims;
+ PJRT_Buffer_Type element_type;
+ PJRT_Buffer_MemoryLayout* layout;
+ // The device that `device_buffer_ptr` is on.
+ PJRT_Device* device;
+ // A callback to be performed when the PJRT_Buffer is done with the on-device
+ // buffer. This callback is optional and can be a nullptr.
+ void (*on_delete_callback)(void* device_buffer_ptr, void* user_arg);
+ // `on_delete_callback_arg` will be passed to `on_delete_callback` as
+ // `user_arg` argument.
+ void* on_delete_callback_arg;
+ // A platform-specific stream handle that should contain the work or events
+ // needed to materialize the on-device buffer. It is optional and can be
+ // casted from a nullptr. PJRT_Client_CreateViewOfDeviceBuffer_Args will
+ // append an event to `stream` that indicates when the returned buffer is
+ // ready to use. This is intended to support dlpack on GPU and is not expected
+ // to be supported on all hardware platforms.
+ intptr_t stream;
+ PJRT_Buffer* buffer; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_CreateViewOfDeviceBuffer_Args, buffer);
+
+// Creates a PJRT buffer that is a non-owned view of an on-device buffer
+// (typically allocated by another library). The buffer may be mutated,
+// for example, if the buffer is donated to an Execute operation. This method is
+// not required on all hardware platforms.
+typedef PJRT_Error* PJRT_Client_CreateViewOfDeviceBuffer(
+ PJRT_Client_CreateViewOfDeviceBuffer_Args* args);
+
+// -------------------------- Device Descriptions ------------------------------
+
+// Device descriptions may be associated with an actual device
+// (via PJRT_Device_GetDescription), but they can also be used to describe a
+// device that isn't currently available to the plugin. This is useful for
+// compiling executables without hardware available, which can then be
+// serialized and written somewhere durable, and then loaded and run on actual
+// hardware later.
+
+struct PJRT_DeviceDescription_Id_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_DeviceDescription* device_description;
+ int id; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_Id_Args, id);
+
+// The ID of this device. IDs are unique among devices of this type
+// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
+// hosts' devices.
+typedef PJRT_Error* PJRT_DeviceDescription_Id(
+ PJRT_DeviceDescription_Id_Args* args);
+
+struct PJRT_DeviceDescription_ProcessIndex_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_DeviceDescription* device_description;
+ int process_index; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_ProcessIndex_Args,
+ process_index);
+
+// The index of the process that this device belongs to, i.e. is addressable
+// from. This is not always identical to PJRT_Client_ProcessIndex in a
+// multi-process setting, where each client can see devices from all
+// processes, but only a subset of them are addressable and have the same
+// process_index as the client.
+typedef PJRT_Error* PJRT_DeviceDescription_ProcessIndex(
+ PJRT_DeviceDescription_ProcessIndex_Args* args);
+
+struct PJRT_DeviceDescription_Attributes_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_DeviceDescription* device_description;
+ size_t num_attributes; // out
+ const PJRT_NamedValue* attributes; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_Attributes_Args, attributes);
+
+// Returns an array of device specific attributes with attribute name, value
+// and value type.
+typedef PJRT_Error* PJRT_DeviceDescription_Attributes(
+ PJRT_DeviceDescription_Attributes_Args* args);
+
+struct PJRT_DeviceDescription_Kind_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_DeviceDescription* device_description;
+ // `device_kind` string is owned by `device` and has same lifetime as
+ // `device`.
+ const char* device_kind; // out
+ size_t device_kind_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_Kind_Args, device_kind_size);
+
+// A vendor-dependent string that uniquely identifies the kind of device,
+// e.g., "Tesla V100-SXM2-16GB".
+typedef PJRT_Error* PJRT_DeviceDescription_Kind(
+ PJRT_DeviceDescription_Kind_Args* args);
+
+struct PJRT_DeviceDescription_DebugString_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_DeviceDescription* device_description;
+ const char* debug_string; // out
+ size_t debug_string_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_DebugString_Args,
+ debug_string_size);
+
+// Debug string suitable for logging when errors occur. Should be verbose
+// enough to describe the current device unambiguously.
+typedef PJRT_Error* PJRT_DeviceDescription_DebugString(
+ PJRT_DeviceDescription_DebugString_Args* args);
+
+struct PJRT_DeviceDescription_ToString_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_DeviceDescription* device_description;
+ const char* to_string; // out
+ size_t to_string_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_ToString_Args, to_string_size);
+
+// Debug string suitable for reading by end users, should be reasonably terse,
+// for example: "CpuDevice(id=0)".
+typedef PJRT_Error* PJRT_DeviceDescription_ToString(
+ PJRT_DeviceDescription_ToString_Args* args);
+
+// --------------------------------- Devices -----------------------------------
+
+struct PJRT_Device_GetDescription_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Device* device;
+ PJRT_DeviceDescription* device_description; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_GetDescription_Args, device_description);
+
+// Fetch the DeviceDescription associated with this device.
+typedef PJRT_Error* PJRT_Device_GetDescription(
+ PJRT_Device_GetDescription_Args* args);
+
+struct PJRT_Device_IsAddressable_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Device* device;
+ bool is_addressable; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_IsAddressable_Args, is_addressable);
+
+// Whether client can issue command to this device.
+typedef PJRT_Error* PJRT_Device_IsAddressable(
+ PJRT_Device_IsAddressable_Args* args);
+
+struct PJRT_Device_LocalHardwareId_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Device* device;
+ int local_hardware_id; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_LocalHardwareId_Args, local_hardware_id);
+
+// Opaque hardware ID, e.g., the CUDA device number. In general, not guaranteed
+// to be dense, and -1 if undefined.
+typedef PJRT_Error* PJRT_Device_LocalHardwareId(
+ PJRT_Device_LocalHardwareId_Args* args);
+
+struct PJRT_Device_AddressableMemories_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Device* device;
+ // Has the lifetime of `device`.
+ PJRT_Memory* const* memories; // out
+ size_t num_memories; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_AddressableMemories_Args, memories);
+
+// Returns the memories that a device can address.
+typedef PJRT_Error* PJRT_Device_AddressableMemories(
+ PJRT_Device_AddressableMemories_Args* args);
+
+struct PJRT_Device_DefaultMemory_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Device* device;
+ // `memory` has the same lifetime as `device`.
+ PJRT_Memory* memory; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_DefaultMemory_Args, memory);
+
+// Returns the default memory of a device, i.e. which memory data processed by
+// this device should be stored in by default.
+typedef PJRT_Error* PJRT_Device_DefaultMemory(
+ PJRT_Device_DefaultMemory_Args* args);
+
+struct PJRT_Device_MemoryStats_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Device* device;
+
+ // Number of bytes in use.
+ int64_t bytes_in_use; // out
+
+ // The peak bytes in use.
+ int64_t peak_bytes_in_use; // out
+ bool peak_bytes_in_use_is_set; // out
+ // Number of allocations.
+ int64_t num_allocs; // out
+ bool num_allocs_is_set; // out
+ // The largest single allocation seen.
+ int64_t largest_alloc_size; // out
+ bool largest_alloc_size_is_set; // out
+ // The upper limit of user-allocatable device memory in bytes.
+ int64_t bytes_limit; // out
+ bool bytes_limit_is_set; // out
+
+ // Number of bytes reserved.
+ int64_t bytes_reserved; // out
+ bool bytes_reserved_is_set; // out
+ // The peak number of bytes reserved.
+ int64_t peak_bytes_reserved; // out
+ bool peak_bytes_reserved_is_set; // out
+ // The upper limit on the number of bytes of reservable memory.
+ int64_t bytes_reservable_limit; // out
+ bool bytes_reservable_limit_is_set; // out
+
+ // Largest free block size in bytes.
+ int64_t largest_free_block_bytes; // out
+ bool largest_free_block_bytes_is_set; // out
+
+ // Number of bytes of memory held by the allocator. This may be higher than
+ // bytes_in_use if the allocator holds a pool of memory (e.g. BFCAllocator).
+ int64_t pool_bytes; // out
+ bool pool_bytes_is_set; // out
+ int64_t peak_pool_bytes; // out
+ bool peak_pool_bytes_is_set; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_MemoryStats_Args, peak_pool_bytes_is_set);
+
+// Device memory/allocator statistics. All returned stats except `bytes_in_use`
+// are optional and may not be returned by all platforms. Implementations may
+// also return PJRT_Error_Code_UNIMPLEMENTED. Intended for diagnostic purposes.
+typedef PJRT_Error* PJRT_Device_MemoryStats(PJRT_Device_MemoryStats_Args* args);
+
+//-------------------------------- Memory --------------------------------------
+
+struct PJRT_Memory_Id_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Memory* memory;
+ int id; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Memory_Id_Args, id);
+
+// The ID of this memory. IDs are unique among memories of this type.
+typedef PJRT_Error* PJRT_Memory_Id(PJRT_Memory_Id_Args* args);
+
+struct PJRT_Memory_Kind_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Memory* memory;
+ // `memory_kind` has same lifetime as `memory`.
+ const char* memory_kind; // out
+ size_t memory_kind_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Memory_Kind_Args, memory_kind_size);
+
+// A platform-dependent string that uniquely identifies the kind of the memory.
+typedef PJRT_Error* PJRT_Memory_Kind(PJRT_Memory_Kind_Args* args);
+
+struct PJRT_Memory_DebugString_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Memory* memory;
+ const char* debug_string; // out
+ size_t debug_string_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Memory_DebugString_Args, debug_string_size);
+
+// Debug string suitable for logging when errors occur. Should be verbose
+// enough to describe the current memory unambiguously.
+typedef PJRT_Error* PJRT_Memory_DebugString(PJRT_Memory_DebugString_Args* args);
+
+struct PJRT_Memory_ToString_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Memory* memory;
+ const char* to_string; // out
+ size_t to_string_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Memory_ToString_Args, to_string_size);
+
+// Debug string suitable for reading by end users, should be reasonably terse.
+typedef PJRT_Error* PJRT_Memory_ToString(PJRT_Memory_ToString_Args* args);
+
+struct PJRT_Memory_AddressableByDevices_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Memory* memory;
+ PJRT_Device* const* devices; // out
+ size_t num_devices; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Memory_AddressableByDevices_Args, num_devices);
+
+// Returns the devices that can address this memory.
+typedef PJRT_Error* PJRT_Memory_AddressableByDevices(
+ PJRT_Memory_AddressableByDevices_Args* args);
+
+// ------------------------------- Executables ---------------------------------
+
+struct PJRT_Executable_Destroy_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Executable* executable;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Destroy_Args, executable);
+
+// Frees `executable`. `executable` can be nullptr.
+typedef PJRT_Error* PJRT_Executable_Destroy(PJRT_Executable_Destroy_Args* args);
+
+struct PJRT_LoadedExecutable_Destroy_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_LoadedExecutable* executable;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Destroy_Args, executable);
+
+// Frees `executable` and deletes the underlying runtime object as if
+// `PJRT_LoadedExecutable_Delete` were called. `executable` can be nullptr.
+typedef PJRT_Error* PJRT_LoadedExecutable_Destroy(
+ PJRT_LoadedExecutable_Destroy_Args* args);
+
+struct PJRT_LoadedExecutable_GetExecutable_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_LoadedExecutable* loaded_executable;
+ PJRT_Executable* executable; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_GetExecutable_Args, executable);
+
+// Constructs a PJRT_Executable from a PJRT_LoadedExecutable. The returned
+// executable should be freed by the caller with PJRT_Executable_Destroy.
+typedef PJRT_Error* PJRT_LoadedExecutable_GetExecutable(
+ PJRT_LoadedExecutable_GetExecutable_Args* args);
+
+struct PJRT_Executable_Name_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Executable* executable;
+ // `executable_name` has the same lifetime as `executable`. It is owned by
+ // `executable`.
+ const char* executable_name; // out
+ size_t executable_name_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Name_Args, executable_name_size);
+
+// Returns a string that identifies the executable.
+typedef PJRT_Error* PJRT_Executable_Name(PJRT_Executable_Name_Args* args);
+
+// TODO(b/269178731): Revisit whether num_replicas is needed.
+struct PJRT_Executable_NumReplicas_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Executable* executable;
+ size_t num_replicas; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_NumReplicas_Args, num_replicas);
+
+// Returns the number of replicas of the executable.
+typedef PJRT_Error* PJRT_Executable_NumReplicas(
+ PJRT_Executable_NumReplicas_Args* args);
+
+struct PJRT_Executable_NumPartitions_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Executable* executable;
+ size_t num_partitions; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_NumPartitions_Args, num_partitions);
+
+// Returns the number of partitions of the executable.
+typedef PJRT_Error* PJRT_Executable_NumPartitions(
+ PJRT_Executable_NumPartitions_Args* args);
+
+struct PJRT_LoadedExecutable_AddressableDevices_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_LoadedExecutable* executable;
+ PJRT_Device* const* addressable_devices; // out
+ size_t num_addressable_devices; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_AddressableDevices_Args,
+ num_addressable_devices);
+
+// Returns a list of devices this executable will run on.
+typedef PJRT_Error* PJRT_LoadedExecutable_AddressableDevices(
+ PJRT_LoadedExecutable_AddressableDevices_Args* args);
+
+struct PJRT_Executable_OptimizedProgram_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Executable* executable;
+ PJRT_Program* program; // out, but read below
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_OptimizedProgram_Args, program);
+
+// Retrieves the optimized program for a given PJRT_Executable (SPMD).
+// The caller should populate `program->format` and `format_size`.
+//
+// The implementation will set `program->format` and `program->format_size`
+// to inform callers of the format of the optimized program returned.
+// These members are owned by the implementation.
+//
+// If called with nullptr as `program->code`, `PJRT_Executable_OptimizedProgram`
+// will populate `program->code_size` as an output indicating the number of
+// bytes the string `program->code` requires.
+//
+// If `program->code` is not null, `PJRT_Executable_OptimizedProgram` will fill
+// the buffer pointed to by `program->code` with the serialization of the
+// optimized HLO program. `program->code` must point to a client-owned buffer of
+// size >= `program->code_size`, which must be at large enough to hold the
+// serialization of the optimized program.
+//
+// Callers should generally call this function twice with the same `args`.
+// In the first call, `program->code` must be nullptr. This call will populate
+// `program->code_size`. Clients should then allocate a buffer `code_buff` of at
+// least `code_size` bytes. Before the second call, callers should set
+// `program->code = code_buff`. The second call will then write the serialized
+// program to `code_buff`.
+typedef PJRT_Error* PJRT_Executable_OptimizedProgram(
+ PJRT_Executable_OptimizedProgram_Args* args);
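+
+// The two-call pattern above, as a sketch (`executable` and the byte buffer
+// are the caller's; error checks omitted):
+//
+//   PJRT_Program program = {0};
+//   program.struct_size = PJRT_Program_STRUCT_SIZE;
+//   program.code = nullptr;                            // 1st call: size query
+//   PJRT_Executable_OptimizedProgram_Args opt_args = {0};
+//   opt_args.struct_size = PJRT_Executable_OptimizedProgram_Args_STRUCT_SIZE;
+//   opt_args.executable = executable;
+//   opt_args.program = &program;
+//   api->PJRT_Executable_OptimizedProgram(&opt_args);  // sets program.code_size
+//   char* code_buff = (char*)malloc(program.code_size);
+//   program.code = code_buff;
+//   api->PJRT_Executable_OptimizedProgram(&opt_args);  // 2nd call: fills buffer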
+
+struct PJRT_LoadedExecutable_Delete_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_LoadedExecutable* executable;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Delete_Args, executable);
+
+// Drops `executable`'s reference to the internal runtime object and
+// associated resources, without freeing the `executable` object itself.
+// `executable` can only be used with PJRT_LoadedExecutable_IsDeleted and
+// PJRT_LoadedExecutable_Destroy after calling this method. The internal runtime
+// executable will be freed after the last execution completes.
+typedef PJRT_Error* PJRT_LoadedExecutable_Delete(
+ PJRT_LoadedExecutable_Delete_Args* args);
+
+struct PJRT_LoadedExecutable_IsDeleted_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_LoadedExecutable* executable;
+ bool is_deleted; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_IsDeleted_Args, is_deleted);
+
+// True if and only if PJRT_LoadedExecutable_Delete has previously been called.
+typedef PJRT_Error* PJRT_LoadedExecutable_IsDeleted(
+ PJRT_LoadedExecutable_IsDeleted_Args* args);
+
+typedef struct PJRT_Chunk {
+ void* data;
+ size_t size;
+ void (*deleter)(void* data, void* deleter_arg);
+ // `deleter_arg` will be passed to `deleter` as `deleter_arg` argument.
+ void* deleter_arg;
+} PJRT_Chunk;
+
+// TODO(b/263390934) implement C API that calls `AddChunk` and other
+// `xla::CopyToDeviceStream`.
+typedef struct PJRT_CopyToDeviceStream PJRT_CopyToDeviceStream;
+
+struct PJRT_TransferMetadata;
+
+// Returns PJRT_Error* created by PJRT_CallbackError in case of error.
+// Otherwise, returns nullptr. The callback must call
+// `chunk->deleter(chunk->data, chunk->deleter_arg)` when it's finished with
+// `chunk`.
+typedef PJRT_Error* (*PJRT_SendCallback)(PJRT_Chunk* chunk,
+ PJRT_CallbackError* callback_error,
+ size_t total_size_in_bytes, bool done,
+ void* user_arg);
+// The callback takes the ownership of the stream object. The callback must call
+// `PJRT_CopyToDeviceStream_Destroy` when it is done with the stream.
+typedef void (*PJRT_RecvCallback)(PJRT_CopyToDeviceStream* stream,
+ void* user_arg);
+
+struct PJRT_SendCallbackInfo {
+ // Used to associate this callback with the correct send op.
+ int64_t channel_id;
+ // Will be passed to `send_callback` as `user_arg` argument.
+ void* user_arg;
+ PJRT_SendCallback send_callback;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_SendCallbackInfo, send_callback);
+
+struct PJRT_RecvCallbackInfo {
+ // Used to associate this callback with the correct recv op.
+ int64_t channel_id;
+ // Will be passed to `recv_callback` as `user_arg` argument.
+ void* user_arg;
+ PJRT_RecvCallback recv_callback;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_RecvCallbackInfo, recv_callback);
+
+struct PJRT_ExecuteOptions {
+ size_t struct_size;
+ void* priv;
+ // Callbacks for when send/recv ops are executed. The outer lists correspond
+ // to each device returned by `PJRT_Executable_AddressableDevices` for
+ // `executable` (i.e. they will have length `num_devices`). Each inner list
+ // contains callback info for each send/recv op in `executable`; the order
+ // doesn't matter as the channel IDs are used instead. The callbacks can be
+ // stateful and the user code is responsible for managing state. The callback
+ // functions must outlive the execution (but not the info structs or lists).
+ PJRT_SendCallbackInfo** send_callbacks;
+ PJRT_RecvCallbackInfo** recv_callbacks;
+ size_t num_send_ops;
+ size_t num_recv_ops;
+ // If non-zero, identifies this execution as part of a potentially
+ // multi-device launch. This can be used to detect scheduling errors, e.g. if
+ // multi-host programs are launched in different orders on different hosts,
+ // the launch IDs may be used by the runtime to detect the mismatch.
+ int launch_id;
+ // A list of indices denoting the input buffers that should not be donated.
+ // An input buffer may be non-donable, for example, if it is referenced more
+ // than once. Since such runtime information is not available at compile time,
+ // the compiler might mark the input as `may-alias`, which could lead PjRt to
+ // donate the input buffer when it should not. By defining this list of
+ // indices, a higher-level PJRT caller can instruct PJRT client not to donate
+ // specific input buffers. The caller needs to make sure to keep it alive
+ // during the call.
+ const int64_t* non_donatable_input_indices;
+ size_t num_non_donatable_input_indices;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_ExecuteOptions, launch_id);
+
+struct PJRT_LoadedExecutable_Execute_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_LoadedExecutable* executable;
+ // Only needs to stay alive for the duration of the Execute call.
+ PJRT_ExecuteOptions* options;
+ // Execution input of size [`num_devices`, `num_args`].
+ PJRT_Buffer* const* const* argument_lists;
+ size_t num_devices;
+ size_t num_args;
+ // Execution output of size [`num_devices`, `num_outputs`], where `num_outputs`
+ // is the number of outputs returned by this executable per device. Both the
+ // outer (`PJRT_Buffer***`) and inner lists (`PJRT_Buffer**`) must be
+ // allocated and deallocated by the caller. PJRT_Buffer_Destroy must be called
+ // on the output PJRT_Buffer*.
+ PJRT_Buffer** const* output_lists; // in/out
+ // If `device_complete_events` isn't nullptr, `device_complete_events` needs
+ // to be the same length as `output_lists` (i.e. of length `num_devices`), and
+ // each `PJRT_Event` will become ready once the corresponding device execution
+ // is complete. If Execute returns an error, then `device_complete_events`
+ // will not be populated. The caller is responsible for calling
+ // PJRT_Event_Destroy on the returned PJRT_Event*s.
+ PJRT_Event** device_complete_events; // in/out
+ // The device to execute on. If nullptr, will execute on the device(s)
+ // specified at compile time. If set, must be an addressable device, and
+ // `num_devices` should be 1 with `argument_lists` only containing arguments
+ // for `execute_device`. Can be set with a multi-device executable to launch
+ // just on this device. In this case, it's the responsibility of the caller to
+ // make sure the executable is launched on all participating devices specified
+ // at compile time. Setting this field may not be supported on all platforms
+ // or executables.
+ PJRT_Device* execute_device;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Execute_Args, execute_device);
+
+// Executes on devices addressable by the client.
+typedef PJRT_Error* PJRT_LoadedExecutable_Execute(
+ PJRT_LoadedExecutable_Execute_Args* args);
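+
+// Single-device execution sketch (`loaded_executable`, `input_buffer` and the
+// `outputs` array, sized via PJRT_Executable_NumOutputs, are the caller's):
+//
+//   PJRT_Buffer* device0_args[] = {input_buffer};
+//   PJRT_Buffer* const* argument_lists[] = {device0_args};
+//   PJRT_Buffer** output_lists[] = {outputs};
+//   PJRT_Event* done_event = nullptr;
+//
+//   PJRT_ExecuteOptions options = {0};
+//   options.struct_size = PJRT_ExecuteOptions_STRUCT_SIZE;
+//
+//   PJRT_LoadedExecutable_Execute_Args exec_args = {0};
+//   exec_args.struct_size = PJRT_LoadedExecutable_Execute_Args_STRUCT_SIZE;
+//   exec_args.executable = loaded_executable;
+//   exec_args.options = &options;
+//   exec_args.argument_lists = argument_lists;
+//   exec_args.num_devices = 1;
+//   exec_args.num_args = 1;
+//   exec_args.output_lists = output_lists;
+//   exec_args.device_complete_events = &done_event;
+//   PJRT_Error* error = api->PJRT_LoadedExecutable_Execute(&exec_args);
+//   // Await done_event, then read the PJRT_Buffer*s written into `outputs`.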
+
+struct PJRT_Executable_NumOutputs_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Executable* executable;
+ size_t num_outputs; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_NumOutputs_Args, num_outputs);
+
+// Gets the number of outputs per device produced by `executable`.
+typedef PJRT_Error* PJRT_Executable_NumOutputs(
+ PJRT_Executable_NumOutputs_Args* args);
+
+struct PJRT_Executable_SizeOfGeneratedCodeInBytes_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Executable* executable;
+ int64_t size_in_bytes; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_SizeOfGeneratedCodeInBytes_Args,
+ size_in_bytes); // last field in the struct
+
+typedef PJRT_Error* PJRT_Executable_SizeOfGeneratedCodeInBytes(
+ PJRT_Executable_SizeOfGeneratedCodeInBytes_Args* args);
+
+struct PJRT_Executable_Fingerprint_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Executable* executable;
+ // Has the lifetime of `executable`
+ const char* executable_fingerprint; // out
+ size_t executable_fingerprint_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Fingerprint_Args,
+ executable_fingerprint_size);
+
+// A unique fingerprint for `executable`. Two executables that were produced by
+// compiling with identical inputs (same program, compile options, compiler
+// version, etc.) should have the same fingerprint. May not be implemented by
+// all platforms.
+typedef PJRT_Error* PJRT_Executable_Fingerprint(
+ PJRT_Executable_Fingerprint_Args* args);
+
+struct PJRT_Executable_GetCostAnalysis_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Executable* executable;
+ size_t num_properties; // out
+ // `properties` and any embedded data are owned by and have the same lifetime
+ // as `executable`.
+ const PJRT_NamedValue* properties; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_GetCostAnalysis_Args, properties);
+
+// Get the cost properties for the executable. Different platforms may return
+// different properties; for example, some platforms may return the number of
+// operations, or memory size of the input/output of the executable, based on
+// program analysis.
+typedef PJRT_Error* PJRT_Executable_GetCostAnalysis(
+ PJRT_Executable_GetCostAnalysis_Args* args);
+
+struct PJRT_Executable_GetCompiledMemoryStats_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Executable* executable;
+
+ // Mirrors xla::CompiledMemoryStats.
+ int64_t generated_code_size_in_bytes; // out
+ int64_t argument_size_in_bytes; // out
+ int64_t output_size_in_bytes; // out
+ // How much argument is reused for output.
+ int64_t alias_size_in_bytes; // out
+ int64_t temp_size_in_bytes; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_GetCompiledMemoryStats_Args,
+ temp_size_in_bytes);
+
+// Return memory stats that allow callers to estimate device memory usage
+// when running this executable.
+typedef PJRT_Error* PJRT_Executable_GetCompiledMemoryStats(
+ PJRT_Executable_GetCompiledMemoryStats_Args* args);
+
+struct PJRT_Executable_OutputElementTypes_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Executable* executable;
+ PJRT_Buffer_Type* output_types; // out
+ size_t num_output_types; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_OutputElementTypes_Args,
+ num_output_types);
+
+// Returns a list of element types for outputs.
+typedef PJRT_Error* PJRT_Executable_OutputElementTypes(
+ PJRT_Executable_OutputElementTypes_Args* args);
+
+struct PJRT_Executable_OutputDimensions_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Executable* executable;
+ size_t num_outputs;
+ // Has length: sum of all elements in the list `dim_sizes`.
+ const int64_t* dims; // out
+ // Has length `num_outputs`.
+ const size_t* dim_sizes; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_OutputDimensions_Args, dim_sizes);
+
+// Returns a list of dimensions for outputs. Each output has an array shape,
+// which is represented by a list of dimensions. The array shapes of all outputs
+// are concatenated into a single list of dimensions.
+typedef PJRT_Error* PJRT_Executable_OutputDimensions(
+ PJRT_Executable_OutputDimensions_Args* args);
+
+struct PJRT_Executable_OutputMemoryKinds_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Executable* executable;
+ size_t num_outputs;
+ // Has length `num_outputs`.
+ const char* const* memory_kinds; // out
+ // Has length `num_outputs`.
+ const size_t* memory_kind_sizes; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_OutputMemoryKinds_Args,
+ memory_kind_sizes);
+
+// Returns a list of memory kind strings for outputs.
+typedef PJRT_Error* PJRT_Executable_OutputMemoryKinds(
+ PJRT_Executable_OutputMemoryKinds_Args* args);
+
+typedef struct PJRT_SerializedExecutable PJRT_SerializedExecutable;
+
+struct PJRT_Executable_Serialize_Args {
+ size_t struct_size;
+ void* priv;
+ const PJRT_Executable* executable;
+
+ // Lives only as long as serialized_executable
+ const char* serialized_bytes; // out
+ size_t serialized_bytes_size; // out
+
+ PJRT_SerializedExecutable* serialized_executable; // backs serialized_bytes.
+ // cleanup fn must be called to free the backing memory for serialized_bytes.
+ // Should only be called once on serialized_executable.
+ void (*serialized_executable_deleter)(
+ PJRT_SerializedExecutable* exec); // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Serialize_Args,
+ serialized_executable_deleter);
+
+// Returns a platform-specific serialization of `executable`. The serialization
+// is not guaranteed to be stable over time.
+typedef PJRT_Error* PJRT_Executable_Serialize(
+ PJRT_Executable_Serialize_Args* args);
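+
+// Serialization sketch showing the ownership contract (the bytes live only as
+// long as `serialized_executable`, so copy them out before calling the
+// deleter; error checks omitted):
+//
+//   PJRT_Executable_Serialize_Args ser_args = {0};
+//   ser_args.struct_size = PJRT_Executable_Serialize_Args_STRUCT_SIZE;
+//   ser_args.executable = executable;
+//   api->PJRT_Executable_Serialize(&ser_args);
+//   // ... copy ser_args.serialized_bytes[0..serialized_bytes_size) out ...
+//   ser_args.serialized_executable_deleter(ser_args.serialized_executable);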
+
+struct PJRT_Executable_DeserializeAndLoad_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ const char* serialized_executable;
+ size_t serialized_executable_size;
+ PJRT_LoadedExecutable* loaded_executable; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_DeserializeAndLoad_Args,
+ loaded_executable);
+
+// Deserializes an executable serialized by `PJRT_Executable_Serialize`.
+// `serialized_executable` must have been produced by the same platform and
+// library version as this one.
+typedef PJRT_Error* PJRT_Executable_DeserializeAndLoad(
+ PJRT_Executable_DeserializeAndLoad_Args* args);
+
+struct PJRT_LoadedExecutable_Fingerprint_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_LoadedExecutable* executable;
+ // Has the lifetime of `executable`
+ const char* executable_fingerprint; // out
+ size_t executable_fingerprint_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Fingerprint_Args,
+ executable_fingerprint_size);
+// DEPRECATED. Will be removed in PJRT version 2.0. Please use
+// PJRT_Executable_Fingerprint instead. A unique fingerprint for `executable`.
+// Two executables that were produced by compiling with identical inputs (same
+// program, compile options, compiler version, etc.) should have the same
+// fingerprint. May not be implemented by all platforms.
+typedef PJRT_Error* PJRT_LoadedExecutable_Fingerprint(
+ PJRT_LoadedExecutable_Fingerprint_Args* args);
+
+// ---------------------------------- Buffers ----------------------------------
+
+struct PJRT_Buffer_Destroy_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_Destroy_Args, buffer);
+
+// Deletes the underlying runtime objects as if 'PJRT_Buffer_Delete' were
+// called and frees `buffer`. `buffer` can be nullptr.
+typedef PJRT_Error* PJRT_Buffer_Destroy(PJRT_Buffer_Destroy_Args* args);
+
+struct PJRT_Buffer_ElementType_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ PJRT_Buffer_Type type; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_ElementType_Args, type);
+
+// Returns the type of the array elements of a buffer.
+typedef PJRT_Error* PJRT_Buffer_ElementType(PJRT_Buffer_ElementType_Args* args);
+
+struct PJRT_Buffer_Dimensions_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ // Has the lifetime of `buffer` and length `num_dims`.
+ const int64_t* dims; // out
+ size_t num_dims; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_Dimensions_Args, num_dims);
+
+// Returns the array shape of `buffer`, i.e. the size of each dimension.
+typedef PJRT_Error* PJRT_Buffer_Dimensions(PJRT_Buffer_Dimensions_Args* args);
+
+struct PJRT_Buffer_UnpaddedDimensions_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ // Has the lifetime of `buffer` and length `num_dims`.
+ const int64_t* unpadded_dims; // out
+ size_t num_dims; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_UnpaddedDimensions_Args, num_dims);
+
+// Returns the unpadded array shape of `buffer`. This usually is equivalent to
+// PJRT_Buffer_Dimensions, but for implementations that support
+// dynamically-sized dimensions via padding to a fixed size, any dynamic
+// dimensions may have a smaller unpadded size than the padded size reported by
+// PJRT_Buffer_Dimensions. ("Dynamic" dimensions are those whose length is
+// only known at runtime, vs. "static" dimensions whose size is fixed at compile
+// time.)
+typedef PJRT_Error* PJRT_Buffer_UnpaddedDimensions(
+ PJRT_Buffer_UnpaddedDimensions_Args* args);
+
+struct PJRT_Buffer_DynamicDimensionIndices_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ // Has the lifetime of `buffer` and length `num_dynamic_dims`.
+ const size_t* dynamic_dim_indices; // out
+ size_t num_dynamic_dims; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_DynamicDimensionIndices_Args,
+ num_dynamic_dims);
+
+// Returns the indices of dynamically-sized dimensions, or an empty list if all
+// dimensions are static. ("Dynamic" dimensions are those whose length is
+// only known at runtime, vs. "static" dimensions whose size is fixed at compile
+// time.)
+typedef PJRT_Error* PJRT_Buffer_DynamicDimensionIndices(
+ PJRT_Buffer_DynamicDimensionIndices_Args* args);
+
+struct PJRT_Buffer_GetMemoryLayout_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ // Layout data is owned by and has the lifetime of `buffer`.
+ PJRT_Buffer_MemoryLayout layout; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_GetMemoryLayout_Args, layout);
+
+// Returns the memory layout of the data in this buffer.
+typedef PJRT_Error* PJRT_Buffer_GetMemoryLayout(
+ PJRT_Buffer_GetMemoryLayout_Args* args);
+
+struct PJRT_Buffer_ToHostBuffer_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* src;
+
+ // The caller can specify an optional host layout. If nullptr, the layout of
+ // the src buffer will be used. The caller is responsible for keeping the
+ // data (`tiled` or `strides`) in `host_layout` alive during the call.
+ PJRT_Buffer_MemoryLayout* host_layout;
+ // `dst` can be nullptr to query required size which will be set into
+ // `dst_size`.
+ void* dst; // in/out
+ // Size of `dst` in bytes. If `dst` is nullptr, then `dst_size` is set to the
+ // size needed. Otherwise, `dst_size` must be greater than or equal to the
+ // needed size.
+ size_t dst_size; // in/out
+
+ // Event that signals when the copy has completed.
+ PJRT_Event* event; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_ToHostBuffer_Args, event);
+
+// Asynchronously copies the buffer's value into a preallocated host buffer.
+typedef PJRT_Error* PJRT_Buffer_ToHostBuffer(
+ PJRT_Buffer_ToHostBuffer_Args* args);
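+
+// Editorial usage sketch (assumptions: `api` and `buffer` as in the sketch
+// above, a C++ caller, error handling elided). It shows the size-query/copy
+// pattern described above: the first call with `dst == nullptr` only fills in
+// the required `dst_size`; the second call starts the transfer, which
+// completes when `event` triggers:
+//
+//   PJRT_Buffer_ToHostBuffer_Args args;
+//   args.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE;
+//   args.priv = nullptr;
+//   args.src = buffer;
+//   args.host_layout = nullptr;  // use the layout of `src`
+//   args.dst = nullptr;          // first call: query the required size
+//   api->PJRT_Buffer_ToHostBuffer(&args);
+//
+//   std::vector<char> host(args.dst_size);
+//   args.dst = host.data();      // second call: perform the copy
+//   api->PJRT_Buffer_ToHostBuffer(&args);
+//
+//   PJRT_Event_Await_Args await_args;
+//   await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
+//   await_args.priv = nullptr;
+//   await_args.event = args.event;
+//   api->PJRT_Event_Await(&await_args);  // block until the copy is done
+//   // Finally, destroy args.event with PJRT_Event_Destroy.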
+
+struct PJRT_Buffer_OnDeviceSizeInBytes_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ size_t on_device_size_in_bytes; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_OnDeviceSizeInBytes_Args,
+ on_device_size_in_bytes);
+
+// Returns the number of bytes of the buffer's storage on the device.
+typedef PJRT_Error* PJRT_Buffer_OnDeviceSizeInBytes(
+ PJRT_Buffer_OnDeviceSizeInBytes_Args* args);
+
+struct PJRT_Buffer_Delete_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_Delete_Args, buffer);
+
+// Drops the buffer's reference to its associated device memory, without freeing
+// the `buffer` object itself. `buffer` can only be used with
+// PJRT_Buffer_IsDeleted and PJRT_Buffer_Destroy after calling this method. The
+// device memory will be freed when all async operations using the buffer have
+// completed, according to the allocation semantics of the underlying platform.
+typedef PJRT_Error* PJRT_Buffer_Delete(PJRT_Buffer_Delete_Args* args);
+
+struct PJRT_Buffer_IsDeleted_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ bool is_deleted; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_IsDeleted_Args, is_deleted);
+
+// True if and only if PJRT_Buffer_Delete has previously been called.
+typedef PJRT_Error* PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args* args);
+
+struct PJRT_Buffer_CopyToDevice_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ PJRT_Device* dst_device;
+ PJRT_Buffer* dst_buffer; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_CopyToDevice_Args, dst_buffer);
+
+// Copies the buffer to device `dst_device` within the same client. Caller is
+// responsible for freeing returned `dst_buffer` with PJRT_Buffer_Destroy.
+// Returns an error if the buffer is already on `dst_device`.
+typedef PJRT_Error* PJRT_Buffer_CopyToDevice(
+ PJRT_Buffer_CopyToDevice_Args* args);
+
+struct PJRT_Buffer_CopyToMemory_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ PJRT_Memory* dst_memory;
+ PJRT_Buffer* dst_buffer; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_CopyToMemory_Args, dst_buffer);
+
+// Copies the buffer to memory `dst_memory` within the same client. Caller is
+// responsible for freeing returned `dst_buffer` with PJRT_Buffer_Destroy.
+// Returns an error if the buffer is already on `dst_memory`.
+typedef PJRT_Error* PJRT_Buffer_CopyToMemory(
+ PJRT_Buffer_CopyToMemory_Args* args);
+
+struct PJRT_Buffer_IsOnCpu_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ bool is_on_cpu; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_IsOnCpu_Args, is_on_cpu);
+
+// Returns whether this buffer is on the CPU and thus allows for certain
+// optimizations.
+typedef PJRT_Error* PJRT_Buffer_IsOnCpu(PJRT_Buffer_IsOnCpu_Args* args);
+
+struct PJRT_Buffer_Device_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ PJRT_Device* device; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_Device_Args, device);
+
+// Returns this buffer's storage device.
+typedef PJRT_Error* PJRT_Buffer_Device(PJRT_Buffer_Device_Args* args);
+
+struct PJRT_Buffer_Memory_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ PJRT_Memory* memory; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_Memory_Args, memory);
+
+// Returns this buffer's storage memory.
+typedef PJRT_Error* PJRT_Buffer_Memory(PJRT_Buffer_Memory_Args* args);
+
+struct PJRT_Buffer_ReadyEvent_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ // The caller is responsible for calling PJRT_Event_Destroy on `event`.
+ PJRT_Event* event; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_ReadyEvent_Args, event);
+
+// Returns an event that is triggered when either of the following happens:
+// * the data in the PJRT_Buffer becomes ready, or
+// * an error has occurred.
+//
+// TODO(b/241967811): change these weird semantics
+// If the buffer has been deleted or donated, the returned event will
+// immediately indicate an error. However, if PJRT_Buffer_ReadyEvent() is
+// called on the buffer before PJRT_Buffer_Delete() is, the returned event will
+// not transition to an error state after PJRT_Buffer_Delete() is called.
+typedef PJRT_Error* PJRT_Buffer_ReadyEvent(PJRT_Buffer_ReadyEvent_Args* args);
+
+struct PJRT_Buffer_UnsafePointer_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ uintptr_t buffer_pointer; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_UnsafePointer_Args, buffer_pointer);
+
+// Returns a platform-dependent address for the given buffer that is often,
+// but not guaranteed to be, the physical/device address.
+typedef PJRT_Error* PJRT_Buffer_UnsafePointer(
+ PJRT_Buffer_UnsafePointer_Args* args);
+
+struct PJRT_Buffer_IncreaseExternalReferenceCount_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_IncreaseExternalReferenceCount_Args,
+ buffer);
+
+// Increments the reference count for the buffer. The reference count indicates
+// the raw buffer data is being shared with another framework (e.g. NumPy,
+// dlpack) and should not be deleted or moved by the PJRT implementation (e.g.
+// for memory compaction). TODO(b/295230663): document more API contract
+// details, e.g. does this block, can the buffer be modified in-place.
+typedef PJRT_Error* PJRT_Buffer_IncreaseExternalReferenceCount(
+ PJRT_Buffer_IncreaseExternalReferenceCount_Args* args);
+
+struct PJRT_Buffer_DecreaseExternalReferenceCount_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_DecreaseExternalReferenceCount_Args,
+ buffer);
+
+// Decrements the reference count for the buffer. Returns an error if the
+// reference count is zero (i.e. PJRT_Buffer_IncreaseExternalReferenceCount was
+// not called beforehand).
+typedef PJRT_Error* PJRT_Buffer_DecreaseExternalReferenceCount(
+ PJRT_Buffer_DecreaseExternalReferenceCount_Args* args);
+
+struct PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ void* device_memory_ptr; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args,
+ device_memory_ptr);
+
+// Returns the opaque device memory data pointer of the buffer. The returned
+// pointer may become invalid at any point unless the external reference count
+// has been raised above 0 via PJRT_Buffer_IncreaseExternalReferenceCount.
+typedef PJRT_Error* PJRT_Buffer_OpaqueDeviceMemoryDataPointer(
+ PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args* args);
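+
+// Editorial usage sketch (same assumptions as the sketches above): pinning a
+// buffer while its raw device pointer is handed to an external framework. The
+// opaque pointer is only guaranteed to stay valid while the external
+// reference count is above zero:
+//
+//   PJRT_Buffer_IncreaseExternalReferenceCount_Args inc_args;
+//   inc_args.struct_size =
+//       PJRT_Buffer_IncreaseExternalReferenceCount_Args_STRUCT_SIZE;
+//   inc_args.priv = nullptr;
+//   inc_args.buffer = buffer;
+//   api->PJRT_Buffer_IncreaseExternalReferenceCount(&inc_args);
+//
+//   PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args ptr_args;
+//   ptr_args.struct_size =
+//       PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args_STRUCT_SIZE;
+//   ptr_args.priv = nullptr;
+//   ptr_args.buffer = buffer;
+//   api->PJRT_Buffer_OpaqueDeviceMemoryDataPointer(&ptr_args);
+//   // ... share ptr_args.device_memory_ptr with the external consumer ...
+//
+//   PJRT_Buffer_DecreaseExternalReferenceCount_Args dec_args;
+//   dec_args.struct_size =
+//       PJRT_Buffer_DecreaseExternalReferenceCount_Args_STRUCT_SIZE;
+//   dec_args.priv = nullptr;
+//   dec_args.buffer = buffer;
+//   api->PJRT_Buffer_DecreaseExternalReferenceCount(&dec_args);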
+
+// ---------------------------- CopyToDeviceStream -----------------------------
+
+struct PJRT_CopyToDeviceStream_Destroy_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_CopyToDeviceStream* stream;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_CopyToDeviceStream_Destroy_Args, stream);
+
+// Frees `stream`. `stream` can be nullptr.
+typedef PJRT_Error* PJRT_CopyToDeviceStream_Destroy(
+ PJRT_CopyToDeviceStream_Destroy_Args* args);
+
+struct PJRT_CopyToDeviceStream_AddChunk_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_CopyToDeviceStream* stream;
+ // Takes ownership of `chunk` (i.e. implementation will call chunk.deleter).
+ PJRT_Chunk* chunk;
+ PJRT_Event* transfer_complete; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_CopyToDeviceStream_AddChunk_Args,
+ transfer_complete);
+
+// Emplaces a new chunk of data to copy to the device. The transfer is started
+// immediately, and the returned event is triggered when the transfer completes
+// or fails.
+//
+// The returned event will indicate an error if the chunk's size causes the
+// amount of transferred data to exceed the total bytes, if the stream is
+// already complete, or if the chunk's size is not a multiple of the granule
+// size.
+typedef PJRT_Error* PJRT_CopyToDeviceStream_AddChunk(
+ PJRT_CopyToDeviceStream_AddChunk_Args* args);
+
+struct PJRT_CopyToDeviceStream_TotalBytes_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_CopyToDeviceStream* stream;
+ int64_t total_bytes; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_CopyToDeviceStream_TotalBytes_Args, total_bytes);
+
+// Returns the total amount of data the stream expects to be transferred.
+typedef PJRT_Error* PJRT_CopyToDeviceStream_TotalBytes(
+ PJRT_CopyToDeviceStream_TotalBytes_Args* args);
+
+struct PJRT_CopyToDeviceStream_GranuleSize_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_CopyToDeviceStream* stream;
+ int64_t granule_size_in_bytes; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_CopyToDeviceStream_GranuleSize_Args,
+ granule_size_in_bytes);
+
+// Returns the granule size in bytes. The size of each chunk added to this
+// stream must be a multiple of this number.
+typedef PJRT_Error* PJRT_CopyToDeviceStream_GranuleSize(
+ PJRT_CopyToDeviceStream_GranuleSize_Args* args);
+
+struct PJRT_CopyToDeviceStream_CurrentBytes_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_CopyToDeviceStream* stream;
+ int64_t current_bytes; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_CopyToDeviceStream_CurrentBytes_Args,
+ current_bytes);
+
+// Returns the amount of data the stream has either already transferred or
+// currently has buffered to transfer.
+typedef PJRT_Error* PJRT_CopyToDeviceStream_CurrentBytes(
+ PJRT_CopyToDeviceStream_CurrentBytes_Args* args);
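+
+// Editorial usage sketch (assumptions: `api` as above and `stream` a
+// PJRT_CopyToDeviceStream* obtained from a host-to-device transfer callback;
+// chunk construction and error handling elided): querying the transfer
+// geometry before adding chunks. Each chunk's size must be a multiple of the
+// granule size, and the sum of all chunks must not exceed the total bytes:
+//
+//   PJRT_CopyToDeviceStream_TotalBytes_Args total_args;
+//   total_args.struct_size =
+//       PJRT_CopyToDeviceStream_TotalBytes_Args_STRUCT_SIZE;
+//   total_args.priv = nullptr;
+//   total_args.stream = stream;
+//   api->PJRT_CopyToDeviceStream_TotalBytes(&total_args);
+//
+//   PJRT_CopyToDeviceStream_GranuleSize_Args gran_args;
+//   gran_args.struct_size =
+//       PJRT_CopyToDeviceStream_GranuleSize_Args_STRUCT_SIZE;
+//   gran_args.priv = nullptr;
+//   gran_args.stream = stream;
+//   api->PJRT_CopyToDeviceStream_GranuleSize(&gran_args);
+//
+//   // For each prepared PJRT_Chunk* chunk (ownership passes to the stream):
+//   PJRT_CopyToDeviceStream_AddChunk_Args add_args;
+//   add_args.struct_size = PJRT_CopyToDeviceStream_AddChunk_Args_STRUCT_SIZE;
+//   add_args.priv = nullptr;
+//   add_args.stream = stream;
+//   add_args.chunk = chunk;
+//   api->PJRT_CopyToDeviceStream_AddChunk(&add_args);
+//   // Wait on (and then destroy) add_args.transfer_complete per chunk.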
+
+// ------------------------------ Device Topology ------------------------------
+
+struct PJRT_TopologyDescription_Create_Args {
+ size_t struct_size;
+ void* priv;
+ const char* topology_name;
+ size_t topology_name_size;
+ // Extra platform-specific options to create a client.
+ const PJRT_NamedValue* create_options;
+ size_t num_options;
+ PJRT_TopologyDescription* topology; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_Create_Args, topology);
+
+// Creates and initializes a new PJRT_TopologyDescription and returns it in
+// `topology`.
+typedef PJRT_Error* PJRT_TopologyDescription_Create(
+ PJRT_TopologyDescription_Create_Args* args);
+
+struct PJRT_TopologyDescription_Destroy_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_TopologyDescription* topology;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_Destroy_Args, topology);
+
+// Frees `topology`. `topology` can be nullptr.
+typedef PJRT_Error* PJRT_TopologyDescription_Destroy(
+ PJRT_TopologyDescription_Destroy_Args* args);
+
+struct PJRT_TopologyDescription_PlatformVersion_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_TopologyDescription* topology;
+ // `platform_version` has the same lifetime as `topology`. It's owned by
+ // `topology`.
+ const char* platform_version; // out
+ size_t platform_version_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_PlatformVersion_Args,
+ platform_version_size);
+
+// Returns a string containing human-readable, platform-specific version info
+// (e.g. the CUDA version on GPU or libtpu version on Cloud TPU).
+typedef PJRT_Error* PJRT_TopologyDescription_PlatformVersion(
+ PJRT_TopologyDescription_PlatformVersion_Args* args);
+
+struct PJRT_TopologyDescription_PlatformName_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_TopologyDescription* topology;
+ // `platform_name` has the same lifetime as `topology`. It is owned by
+ // `topology`.
+ const char* platform_name; // out
+ size_t platform_name_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_PlatformName_Args,
+ platform_name_size);
+
+// Returns a string that identifies the platform (e.g. "cpu", "gpu", "tpu").
+typedef PJRT_Error* PJRT_TopologyDescription_PlatformName(
+ PJRT_TopologyDescription_PlatformName_Args* args);
+
+struct PJRT_TopologyDescription_GetDeviceDescriptions_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_TopologyDescription* topology;
+ // Has the same lifetime as topology.
+ PJRT_DeviceDescription* const* descriptions; // out
+ size_t num_descriptions; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_GetDeviceDescriptions_Args,
+ num_descriptions);
+
+// Returns descriptions for all devices in this topology. The device
+// descriptions can be returned in any order, but will be in the same order
+// across calls within a process.
+typedef PJRT_Error* PJRT_TopologyDescription_GetDeviceDescriptions(
+ PJRT_TopologyDescription_GetDeviceDescriptions_Args* args);
+
+typedef struct PJRT_SerializedTopology PJRT_SerializedTopology;
+
+struct PJRT_TopologyDescription_Serialize_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_TopologyDescription* topology;
+
+ // Lives only as long as serialized_topology.
+ const char* serialized_bytes; // out
+ size_t serialized_bytes_size; // out
+
+ PJRT_SerializedTopology* serialized_topology; // out
+ // Must be called exactly once to free the backing memory for
+ // serialized_bytes.
+ void (*serialized_topology_deleter)(
+ PJRT_SerializedTopology* serialized_topology); // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_Serialize_Args,
+ serialized_topology_deleter);
+
+// Serializes the TopologyDescription to a string for use in cache keys.
+typedef PJRT_Error* PJRT_TopologyDescription_Serialize(
+ PJRT_TopologyDescription_Serialize_Args* args);
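+
+// Editorial usage sketch (assumptions: `api` and a live `topology`; error
+// handling elided): serializing a topology, e.g. for use in a cache key. The
+// deleter must be invoked exactly once when the bytes are no longer needed:
+//
+//   PJRT_TopologyDescription_Serialize_Args ser_args;
+//   ser_args.struct_size = PJRT_TopologyDescription_Serialize_Args_STRUCT_SIZE;
+//   ser_args.priv = nullptr;
+//   ser_args.topology = topology;
+//   api->PJRT_TopologyDescription_Serialize(&ser_args);
+//   std::string key(ser_args.serialized_bytes, ser_args.serialized_bytes_size);
+//   ser_args.serialized_topology_deleter(ser_args.serialized_topology);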
+
+struct PJRT_TopologyDescription_Attributes_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_TopologyDescription* topology;
+
+ // Only lives as long as topology.
+ const PJRT_NamedValue* attributes; // out
+ size_t num_attributes; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_Attributes_Args,
+ num_attributes);
+
+// Returns platform-specific topology attributes.
+typedef PJRT_Error* PJRT_TopologyDescription_Attributes(
+ PJRT_TopologyDescription_Attributes_Args* args);
+
+struct PJRT_Compile_Args {
+ size_t struct_size;
+ void* priv;
+ const PJRT_TopologyDescription* topology;
+ // Only needs to stay alive for the duration of the Compile call.
+ // `program->format` and `program->format_size` are owned by the caller.
+ const PJRT_Program* program;
+ // TODO(b/240560013): consider putting some of option fields in priv.
+ // Serialized CompileOptionsProto
+ // (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/compile_options.proto)
+ const char* compile_options;
+ size_t compile_options_size;
+ // Optionally provided for performance-guided optimizations.
+ PJRT_Client* client;
+ PJRT_Executable* executable; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Compile_Args, executable);
+
+// Compiles a program in the specified format (such as MLIR or HLO) with the
+// given compile options. The returned executable must be loaded by a
+// compatible PJRT_Client before execution.
+typedef PJRT_Error* PJRT_Compile(PJRT_Compile_Args* args);
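+
+// Editorial usage sketch (assumptions: `api` and `topology` as above, a
+// PJRT_Program `program` already populated with code and format, and
+// `options_proto` a std::string holding a serialized CompileOptionsProto;
+// error handling elided): ahead-of-time compilation against a topology.
+//
+//   PJRT_Compile_Args compile_args;
+//   compile_args.struct_size = PJRT_Compile_Args_STRUCT_SIZE;
+//   compile_args.priv = nullptr;
+//   compile_args.topology = topology;
+//   compile_args.program = &program;
+//   compile_args.compile_options = options_proto.data();
+//   compile_args.compile_options_size = options_proto.size();
+//   compile_args.client = nullptr;  // optional, see comment above
+//   api->PJRT_Compile(&compile_args);
+//   // compile_args.executable now holds the (unloaded) PJRT_Executable.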
+
+// -------------------------------- Extension ----------------------------------
+
+typedef enum {
+ PJRT_Structure_Type_Gpu_Custom_Call = 0,
+ PJRT_Structure_Type_Profiler,
+} PJRT_Structure_Type;
+
+// PJRT_Structure_Base contains a type and a pointer to the next
+// PJRT_Structure_Base. The framework can walk this chain to find a structure
+// and identify it by its type.
+typedef struct PJRT_Structure_Base {
+ PJRT_Structure_Type type;
+ const struct PJRT_Structure_Base* next;
+} PJRT_Structure_Base;
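+
+// Editorial usage sketch: walking the extension chain that starts at
+// PJRT_Api::extension_start (declared below) to look for a particular
+// extension type; a null result means the plugin does not provide it:
+//
+//   const PJRT_Structure_Base* ext =
+//       static_cast<const PJRT_Structure_Base*>(api->extension_start);
+//   while (ext != nullptr && ext->type != PJRT_Structure_Type_Profiler) {
+//     ext = ext->next;
+//   }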
+
+// -------------------------------- API access ---------------------------------
+
+#define _PJRT_API_STRUCT_FIELD(fn_type) fn_type* fn_type
+
+// Please modify PJRT_Api_STRUCT_SIZE if the last field of PJRT_Api is changed.
+typedef struct PJRT_Api {
+ size_t struct_size;
+ void* extension_start;
+
+ PJRT_Api_Version pjrt_api_version;
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Error_Destroy);
+ _PJRT_API_STRUCT_FIELD(PJRT_Error_Message);
+ _PJRT_API_STRUCT_FIELD(PJRT_Error_GetCode);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Plugin_Initialize);
+ _PJRT_API_STRUCT_FIELD(PJRT_Plugin_Attributes);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Event_Destroy);
+ _PJRT_API_STRUCT_FIELD(PJRT_Event_IsReady);
+ _PJRT_API_STRUCT_FIELD(PJRT_Event_Error);
+ _PJRT_API_STRUCT_FIELD(PJRT_Event_Await);
+ _PJRT_API_STRUCT_FIELD(PJRT_Event_OnReady);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_Create);
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_Destroy);
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_PlatformName);
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_ProcessIndex);
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_PlatformVersion);
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_Devices);
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_AddressableDevices);
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_LookupDevice);
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_LookupAddressableDevice);
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_AddressableMemories);
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_Compile);
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_DefaultDeviceAssignment);
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_BufferFromHostBuffer);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_DeviceDescription_Id);
+ _PJRT_API_STRUCT_FIELD(PJRT_DeviceDescription_ProcessIndex);
+ _PJRT_API_STRUCT_FIELD(PJRT_DeviceDescription_Attributes);
+ _PJRT_API_STRUCT_FIELD(PJRT_DeviceDescription_Kind);
+ _PJRT_API_STRUCT_FIELD(PJRT_DeviceDescription_DebugString);
+ _PJRT_API_STRUCT_FIELD(PJRT_DeviceDescription_ToString);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Device_GetDescription);
+ _PJRT_API_STRUCT_FIELD(PJRT_Device_IsAddressable);
+ _PJRT_API_STRUCT_FIELD(PJRT_Device_LocalHardwareId);
+ _PJRT_API_STRUCT_FIELD(PJRT_Device_AddressableMemories);
+ _PJRT_API_STRUCT_FIELD(PJRT_Device_DefaultMemory);
+ _PJRT_API_STRUCT_FIELD(PJRT_Device_MemoryStats);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Memory_Id);
+ _PJRT_API_STRUCT_FIELD(PJRT_Memory_Kind);
+ _PJRT_API_STRUCT_FIELD(PJRT_Memory_DebugString);
+ _PJRT_API_STRUCT_FIELD(PJRT_Memory_ToString);
+ _PJRT_API_STRUCT_FIELD(PJRT_Memory_AddressableByDevices);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_Destroy);
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_Name);
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_NumReplicas);
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_NumPartitions);
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_NumOutputs);
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_SizeOfGeneratedCodeInBytes);
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_GetCostAnalysis);
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_OutputMemoryKinds);
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_OptimizedProgram);
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_Serialize);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_LoadedExecutable_Destroy);
+ _PJRT_API_STRUCT_FIELD(PJRT_LoadedExecutable_GetExecutable);
+ _PJRT_API_STRUCT_FIELD(PJRT_LoadedExecutable_AddressableDevices);
+ _PJRT_API_STRUCT_FIELD(PJRT_LoadedExecutable_Delete);
+ _PJRT_API_STRUCT_FIELD(PJRT_LoadedExecutable_IsDeleted);
+ _PJRT_API_STRUCT_FIELD(PJRT_LoadedExecutable_Execute);
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_DeserializeAndLoad);
+ _PJRT_API_STRUCT_FIELD(PJRT_LoadedExecutable_Fingerprint);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Destroy);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_ElementType);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Dimensions);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_UnpaddedDimensions);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_DynamicDimensionIndices);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_GetMemoryLayout);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_OnDeviceSizeInBytes);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Device);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Memory);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Delete);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_IsDeleted);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyToDevice);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_ToHostBuffer);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_IsOnCpu);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_ReadyEvent);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_UnsafePointer);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_IncreaseExternalReferenceCount);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_DecreaseExternalReferenceCount);
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_OpaqueDeviceMemoryDataPointer);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_CopyToDeviceStream_Destroy);
+ _PJRT_API_STRUCT_FIELD(PJRT_CopyToDeviceStream_AddChunk);
+ _PJRT_API_STRUCT_FIELD(PJRT_CopyToDeviceStream_TotalBytes);
+ _PJRT_API_STRUCT_FIELD(PJRT_CopyToDeviceStream_GranuleSize);
+ _PJRT_API_STRUCT_FIELD(PJRT_CopyToDeviceStream_CurrentBytes);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_TopologyDescription_Create);
+ _PJRT_API_STRUCT_FIELD(PJRT_TopologyDescription_Destroy);
+ _PJRT_API_STRUCT_FIELD(PJRT_TopologyDescription_PlatformName);
+ _PJRT_API_STRUCT_FIELD(PJRT_TopologyDescription_PlatformVersion);
+ _PJRT_API_STRUCT_FIELD(PJRT_TopologyDescription_GetDeviceDescriptions);
+ _PJRT_API_STRUCT_FIELD(PJRT_TopologyDescription_Serialize);
+ _PJRT_API_STRUCT_FIELD(PJRT_TopologyDescription_Attributes);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Compile);
+
+ // Always add new fields to the end of the struct. Move fields below to their
+ // corresponding places after each major version bump.
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_OutputElementTypes);
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_OutputDimensions);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyToMemory);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_CreateViewOfDeviceBuffer);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_Fingerprint);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_TopologyDescription);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_GetCompiledMemoryStats);
+} PJRT_Api;
+
+enum {
+ PJRT_Api_STRUCT_SIZE =
+ PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_TopologyDescription)
+};
+
+#undef _PJRT_API_STRUCT_FIELD
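+
+// Editorial usage sketch: PJRT plugins conventionally export a
+// `const PJRT_Api* GetPjrtApi()` entry point (a plugin convention, not
+// something this header mandates); `handle` below is assumed to come from
+// dlopen()-ing such a plugin. Before calling a function pointer that was
+// appended after the initial release, check that the plugin's reported
+// struct_size actually covers it:
+//
+//   typedef const PJRT_Api* (*GetPjrtApiFn)();
+//   auto get_api = reinterpret_cast<GetPjrtApiFn>(dlsym(handle, "GetPjrtApi"));
+//   const PJRT_Api* api = get_api();
+//   if (api->struct_size >=
+//       PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Buffer_CopyToMemory)) {
+//     // Safe to call api->PJRT_Buffer_CopyToMemory(...).
+//   }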
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // XLA_PJRT_C_PJRT_C_API_H_