From b9f5c6733f209e73dcdc148725ebe55817294f4c Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Sat, 24 Feb 2024 21:50:16 +0100 Subject: [PATCH] XLA backend for Lc0 (#1949) (cherry picked from commit 04f73fc075fdccedb558d72d8f0ad5efb7d1f53a) --- meson.build | 16 + meson_options.txt | 5 + scripts/compile_proto.py | 1 + src/neural/xla/hlo.proto | 412 +++++++ src/neural/xla/hlo_builder.cc | 312 +++++ src/neural/xla/hlo_builder.h | 107 ++ src/neural/xla/network_xla.cc | 305 +++++ src/neural/xla/onnx2hlo.cc | 532 ++++++++ src/neural/xla/onnx2hlo.h | 73 ++ src/neural/xla/pjrt.cc | 403 ++++++ src/neural/xla/pjrt.h | 252 ++++ src/neural/xla/print_hlo.cc | 390 ++++++ src/neural/xla/print_hlo.h | 44 + src/neural/xla/xla_runner.cc | 214 ++++ src/neural/xla/xla_runner.h | 134 ++ third_party/pjrt_c_api.h | 2175 +++++++++++++++++++++++++++++++++ 16 files changed, 5375 insertions(+) create mode 100644 src/neural/xla/hlo.proto create mode 100644 src/neural/xla/hlo_builder.cc create mode 100644 src/neural/xla/hlo_builder.h create mode 100644 src/neural/xla/network_xla.cc create mode 100644 src/neural/xla/onnx2hlo.cc create mode 100644 src/neural/xla/onnx2hlo.h create mode 100644 src/neural/xla/pjrt.cc create mode 100644 src/neural/xla/pjrt.h create mode 100644 src/neural/xla/print_hlo.cc create mode 100644 src/neural/xla/print_hlo.h create mode 100644 src/neural/xla/xla_runner.cc create mode 100644 src/neural/xla/xla_runner.h create mode 100644 third_party/pjrt_c_api.h diff --git a/meson.build b/meson.build index a9ba5caf95..d7e138cef7 100644 --- a/meson.build +++ b/meson.build @@ -623,6 +623,22 @@ if get_option('build_backends') endif + ## ~~~~~~~~ + ## XLA + ## ~~~~~~~~ + if get_option('xla') + files += [ + 'src/neural/xla/hlo_builder.cc', + 'src/neural/xla/network_xla.cc', + 'src/neural/xla/onnx2hlo.cc', + 'src/neural/xla/print_hlo.cc', + 'src/neural/xla/pjrt.cc', + 'src/neural/xla/xla_runner.cc', + ] + files += gen_proto_src.process('src/neural/xla/hlo.proto', + preserve_path_from : meson.current_source_dir() + '/src/') + deps += cc.find_library('dl', required: false) + endif endif # if get_option('build_backends') diff --git a/meson_options.txt b/meson_options.txt index 645815a914..d5e74d976e 100644 --- a/meson_options.txt +++ b/meson_options.txt @@ -187,3 +187,8 @@ option('onnx_include', type: 'string', value: 'D:/IDE/Microsoft Visual Studio/Projects/lc0/subprojects/onnx/include', description: 'Paths to ONNX runtime includes') + +option('xla', + type: 'boolean', + value: false, + description: 'Enable XLA backend') \ No newline at end of file diff --git a/scripts/compile_proto.py b/scripts/compile_proto.py index 9163683a4e..cb7d0450b2 100755 --- a/scripts/compile_proto.py +++ b/scripts/compile_proto.py @@ -64,6 +64,7 @@ 'package', 'message', 'optional', + 'required', 'repeated', 'enum', ] + list(TYPES.keys()) diff --git a/src/neural/xla/hlo.proto b/src/neural/xla/hlo.proto new file mode 100644 index 0000000000..a374b25837 --- /dev/null +++ b/src/neural/xla/hlo.proto @@ -0,0 +1,412 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. +*/ + +syntax = "proto2"; + +package pblczero; + +message XlaLayoutProto { + // Sequence of dimension numbers, from minor (fastest varying index) to major + // (slowest varying index). This field is required. + repeated int64 minor_to_major = 1; +} + +message XlaShapeProto { + enum Type { + PRIMITIVE_TYPE_INVALID = 0; + + // Predicates are two-state booleans. + PRED = 1; + + // Signed integral values of fixed width. + S4 = 21; + S8 = 2; + S16 = 3; + S32 = 4; + S64 = 5; + + // Unsigned integral values of fixed width. + U4 = 22; + U8 = 6; + U16 = 7; + U32 = 8; + U64 = 9; + + // Floating-point values of fixed width. + // + // Note: if f16s are not natively supported on the device, they will be + // converted to f16 from f32 at arbirary points in the computation. + F16 = 10; + F32 = 11; + + // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit + // floating-point format, but uses 1 bit for the sign, 8 bits for the + // exponent and 7 bits for the mantissa. + BF16 = 16; + + F64 = 12; + + // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2209.05433 + // + // F8E5M2 has 5 exponent bits and 2 mantissa bits, and is similar to the + // existing IEEE types. + // + // F8E4M3FN has 4 exponent bits and 3 mantissa bits. The "FN" means only + // Finite and NaN values are supported. Unlike IEEE types, infinities are + // not supported. NaN is represented when the exponent and mantissa bits + // are all 1s. All other values are finite. + // + // F8E4M3B11FNUZ has 4 exponent bits and 3 mantissa bits and a bias of 11. + // The "FNUZ" means only Finite and NaN values are supported; zero is + // unsigned. Unlike IEEE types, infinities are not supported. NaN is + // represented when the exponent and mantissa bits are all 0s with a sign + // bit of 1. All other values are finite. + + F8E5M2 = 19; + F8E4M3FN = 20; + F8E4M3B11FNUZ = 23; + + // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2206.02915 + // + // F8E5M2FNUZ has 5 exponent bits and 2 mantissa bits. + // F8E4M3FNUZ has 4 exponent bits and 3 mantissa bits. + // + // The "FNUZ" means only Finite and NaN values are supported; zero is + // unsigned. Unlike IEEE types, infinities are not supported. NaN is + // represented when the exponent and mantissa bits are all 0s with a sign + // bit of 1. All other values are finite. + // + // These differences mean there's an additional exponent value available. To + // keep the same dynamic range as an IEEE-like FP8 type, the exponent is + // biased one more than would be expected given the number of exponent bits + // (8 for Float8E4M3FNUZ and 16 for Float8E5M2FNUZ). + F8E5M2FNUZ = 24; + F8E4M3FNUZ = 25; + + // Complex values of fixed width. + C64 = 15; // Paired F32 (real, imag), as in std::complex. + C128 = 18; // Paired F64 (real, imag), as in std::complex. + + // A tuple is a polymorphic sequence; e.g. a shape that holds different + // sub-shapes. They are used for things like returning multiple values from + // a computation; e.g. a computation that returns weights and biases may + // have a signature that results in a tuple like (f32[784x2000], f32[2000]) + // + // If a shape proto has the tuple element type, it may not have any entries + // in the dimensions field. + TUPLE = 13; + + // An opaque type used for passing context-specific data to a custom + // operation. Shapes of this primitive type will have empty dimensions and + // tuple_shapes fields. + // + // (OPAQUE would be a better name for this identifier, but that conflicts + // with a macro defined in windows.h.) + OPAQUE_TYPE = 14; + + // A token type threaded between side-effecting operations. Shapes of this + // primitive type will have empty dimensions and tuple_shapes fields. + TOKEN = 17; + } + + // The element type for this shape. + required Type element_type = 2; + + // The size (number of elements) for each dimension, or an upper bound on the + // size if the dimension is dynamic. In XLA, dimensions are numbered from 0 + // to N-1 for an N-dimensional array. The first element of 'dimensions' is the + // size of dimension 0, the second element is the size of dimension 1, and so + // forth. Empty list indicates a scalar. + // + // If the respective element in 'is_dimension_dynamic' is true then the value + // in this field represents an upper bound on the size of the dimension. + repeated int64 dimensions = 3; + + // For tuples only, the shapes of constituent shapes in the tuple sequence. + repeated XlaShapeProto tuple_shapes = 4; + + // The layout used to back this shape. + required XlaLayoutProto layout = 5; + + // For arrays, this indicates whether or not each dimension is + // dynamically-sized. The number of elements in this repeated field should be + // zero (indicating that no dimensions are dynamic) or equal to the number of + // elements in the 'dimensions' field. + repeated bool is_dynamic_dimension = 6; +} + +// Shape of the parameters and output of a computation (like a traditional +// function signature). +message XlaProgramShapeProto { + repeated XlaShapeProto parameters = 1; + required XlaShapeProto result = 2; + repeated string parameter_names = 3; +} + +// Symbolization metadata for HLO Instructions. +// +// This metadata is used for debugging XLA code generation, as well as +// performance profiling of XLA-generated executables. +message XlaOpMetadata { + // The framework op name that generated this XLA op. + // + // Frameworks that build on top of XLA should mirror the names of their ops + // back to users by specifying the op_type. In this way, even if the + // framework's "ops" are implemented as multiple XLA HLO Ops, they can be + // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as + // multiple ops, then each op should have the op_type be "SoftMax".) + optional string op_type = 1; + // The user-specified name of the op. + // + // This name is often unique within a computation. Note: some frameworks + // add auto-generated names if the user does not provide one. + optional string op_name = 2; + // Indicate a file and line that this op is associated to in a user's program. + // + // e.g. it could be the file and line of user code that generated the op. + optional string source_file = 3; + optional int32 source_line = 4; +} + +message XlaLiteralProto { + required XlaShapeProto shape = 1; + repeated bool preds = 2; + optional bytes s4s = 21; + optional bytes u4s = 22; + optional bytes s8s = 15; + optional bytes u8s = 3; + repeated int32 s32s = 4; + repeated int64 s64s = 5; + repeated uint32 u32s = 6; + repeated uint64 u64s = 7; + repeated float f32s = 8; + repeated double f64s = 9; + repeated float c64s = 12; // Stored as interleaved real, imag floats. + repeated double c128s = 18; // Stored as interleaved real, imag doubles. + repeated XlaLiteralProto tuple_literals = 10; + // The F16s, BF16s, U16s and S16s are encoded in little endian byte order + optional bytes f16s = 11; + optional bytes bf16s = 13; + optional bytes u16s = 16; + optional bytes s16s = 17; + optional bytes f8e5m2s = 19; + optional bytes f8e4m3fns = 20; + optional bytes f8e4m3b11fnuzs = 23; + optional bytes f8e5m2fnuzs = 24; + optional bytes f8e4m3fnuzs = 25; + repeated int64 sparse_indices = 14; + // Next = 26 +} + +message XlaWindowDimension { + optional int64 size = 1; + optional int64 stride = 2; + optional int64 padding_low = 3; + optional int64 padding_high = 4; + optional int64 window_dilation = 5; + optional int64 base_dilation = 6; + optional bool window_reversal = 7; +} + +message XlaWindow { repeated XlaWindowDimension dimensions = 1; } + +message XlaConvolutionDimensionNumbers { + optional int64 input_batch_dimension = 7; + optional int64 input_feature_dimension = 8; + repeated int64 input_spatial_dimensions = 11; + optional int64 kernel_input_feature_dimension = 3; + optional int64 kernel_output_feature_dimension = 4; + repeated int64 kernel_spatial_dimensions = 6; + optional int64 output_batch_dimension = 9; + optional int64 output_feature_dimension = 10; + repeated int64 output_spatial_dimensions = 12; +} + +message XlaDotDimensionNumbers { + repeated int64 lhs_contracting_dimensions = 1; + repeated int64 rhs_contracting_dimensions = 2; + repeated int64 lhs_batch_dimensions = 3; + repeated int64 rhs_batch_dimensions = 4; +} + +message HloInstructionProto { + required string name = 1; + required string opcode = 2; + required XlaShapeProto shape = 3; + + optional XlaOpMetadata metadata = 7; + + // Literal, only present for kConstant. + optional XlaLiteralProto literal = 8; + + // Parameter number is only present for kParameter. + optional int64 parameter_number = 9; + + // Index for kGetTupleElement. + optional int64 tuple_index = 13; + + // Describes the window in a windowed operation such as convolution. + optional XlaWindow window = 15; + + // Describes the dimension numbers used for a convolution. + optional XlaConvolutionDimensionNumbers convolution_dimension_numbers = 16; + + optional XlaDotDimensionNumbers dot_dimension_numbers = 30; + + // Dimensions present for some operations that require reshaping or + // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. + repeated int64 dimensions = 14; + + // The id of this instruction. + required int64 id = 35; + + repeated int64 operand_ids = 36; + repeated int64 called_computation_ids = 38; +} + +message HloComputationProto { + required string name = 1; + + // The array of instructions is always in a valid dependency order, where + // operands appear before their users. + repeated HloInstructionProto instructions = 2; + required XlaProgramShapeProto program_shape = 4; + + // The id of this computation. + required int64 id = 5; + + // The id of the root of the computation. + required int64 root_id = 6; +} + +message HloModuleProto { + required string name = 1; + required string entry_computation_name = 2; + required int64 entry_computation_id = 6; + + // The array of computations is always in a valid dependency order, where + // callees appear before their callers. + repeated HloComputationProto computations = 3; + + // The host program shape (with layout) of the entry computation. + required XlaProgramShapeProto host_program_shape = 4; + + // The id of this module. + required int64 id = 5; +} + +message OptionOverrideProto { + optional string string_field = 1; + optional bool bool_field = 2; + optional int64 int_field = 3; + optional double double_field = 4; +} + +message CompileEnvOptionProto { + required string key = 1; + required OptionOverrideProto value = 2; +} + +message ExecutableBuildOptionsProto { + // If set, this is the device to build the computation for. Valid + // device_ordinal values are: 0 to # of devices - 1. These values are + // identical to the device ordinal values used by StreamExecutor. The built + // executable will be executable on any device equivalent to the specified + // device as determined by Backend::devices_equivalent(). A value of -1 + // indicates this option has not been set. + optional int64 device_ordinal = 1; + + // If set, this specifies the layout of the result of the computation. If not + // set, the service will chose the layout of the result. A Shape is used to + // store the layout to accommodate tuple result shapes. A value of nullptr + // indicates the option has not been set. + optional XlaShapeProto result_layout = 2; + + // The number of replicas of this computation that are to be executed. + // Defaults to 1. + optional int64 num_replicas = 4; + + // The number of partitions in this computation. Defaults to 1. + optional int64 num_partitions = 5; + + // Indicates whether to use SPMD (true) or MPMD (false) partitioning when + // num_partitions > 1 and XLA is requested to partition the input program. + optional bool use_spmd_partitioning = 6; + + // Whether to automatically generate XLA shardings for SPMD partitioner. + optional bool use_auto_spmd_partitioning = 7; + + // Whether HLOs should be deduplicated. + optional bool deduplicate_hlo = 8; + + // Whether input and output buffers are aliased if the associated parameter is + // passed-through XLA modules without being changed. + optional bool alias_passthrough_params = 10; + + // By default, XLA builds an executable by invoking standard compilation, i.e. + // running Compiler::Compile, or both Compiler::RunHloPasses and + // Compiler::RunBackend. When run_backend_only is set to true, XLA builds an + // executable by invoking only RunBackend and skip invoking RunHloPasses, + // which can be used to compile post-optimizations HLO modules. + optional bool run_backend_only = 11; + + // Allows sharding propagation to propagate to the outputs. This changes the + // output shape of the computation (which is undesirable), but it can be used + // to allow to run partial compilation to determine what would be the output + // sharding of a computation if XLA would be allowed to propagate the sharding + // which can be used by higher level framework as a way to query intermediate + // sharding of operations when multiple computation would be chained and + // merged together. + // This is a vector of bool, because the user can control (if the output of + // the computation is a tuple) which elements of the tuple can have the + // sharding substituted and which don't. If only one boolean value is passed + // in the vector that's interpreted as the value to be applied for every + // single element of the output tuple. One value per element of the tuple + // means that each value is attached to one of the output elements. + repeated bool allow_spmd_sharding_propagation_to_output = 12; + + // Opaque profile data for any feedback directed optimizations. + optional bytes fdo_profile = 14; + + optional int64 device_memory_size = 15; + + // Mesh shape in auto sharding options. + repeated int64 auto_spmd_partitioning_mesh_shape = 16; + + // Mesh ids in auto sharding options. + repeated int64 auto_spmd_partitioning_mesh_ids = 17; +} + +message CompileOptionsProto { + repeated XlaShapeProto argument_layouts = 1; + optional bool parameter_is_tupled_arguments = 2; + optional ExecutableBuildOptionsProto executable_build_options = 3; + optional bool compile_portable_executable = 4; + optional int64 profile_version = 5; + optional bytes serialized_multi_slice_config = 6; + repeated CompileEnvOptionProto env_options = 7; +} \ No newline at end of file diff --git a/src/neural/xla/hlo_builder.cc b/src/neural/xla/hlo_builder.cc new file mode 100644 index 0000000000..5178197863 --- /dev/null +++ b/src/neural/xla/hlo_builder.cc @@ -0,0 +1,312 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. +*/ + +#include "neural/xla/hlo_builder.h" + +#include + +#include "utils/exception.h" +#include "utils/logging.h" + +namespace lczero { + +// Creates an instruction and populates required fields of the +// HloInstructionProto: result shape, opcode and operands. +// Appends the instruction to the entry computation. +pblczero::HloInstructionProto* HloBuilder::MakeInstruction( + std::string_view opcode, const pblczero::XlaShapeProto& shape, + const std::vector operands) { + auto instr = std::make_unique(); + auto ret = instr.get(); + ret->set_opcode(opcode); + *ret->mutable_shape() = shape; + *ret->mutable_metadata() = metadata_; + ret->set_id(entry_computation_.size()); + for (const auto& operand : operands) { + ret->add_operand_ids(operand->id()); + } + entry_computation_.push_back(std::move(instr)); + return ret; +} + +// Creates an elementwise instruction, which always have two operands of the +// same shape. +pblczero::HloInstructionProto* HloBuilder::MakeElementwiseInstruction( + std::string_view opcode, HloFlow lhs, HloFlow rhs) { + if (lhs->shape().dimensions() != rhs->shape().dimensions()) { + throw Exception("Elementwise operands must have the same shape"); + } + return MakeInstruction(opcode, lhs->shape(), {lhs, rhs}); +} + +//////////////////////////////////////////////////////////////////////////// +// Instructions. +//////////////////////////////////////////////////////////////////////////// + +HloFlow HloBuilder::Parameter(const pblczero::XlaShapeProto& shape) { + return MakeInstruction("parameter", shape, {}); +} + +// Converts the element types while keeping the shape. +HloFlow HloBuilder::Convert(HloFlow input, + const pblczero::XlaShapeProto::Type type) { + if (input->shape().element_type() == type) return input; + pblczero::XlaShapeProto shape = input->shape(); + shape.set_element_type(type); + return MakeInstruction("convert", shape, {input}); +} + +HloFlow HloBuilder::Constant(const pblczero::XlaLiteralProto& literal) { + auto* flow = MakeInstruction("constant", literal.shape(), {}); + *flow->mutable_literal() = literal; + return flow; +} + +HloFlow HloBuilder::Convolution( + HloFlow input, HloFlow filter, const pblczero::XlaWindow& window, + const pblczero::XlaConvolutionDimensionNumbers& dn) { + if (input->shape().dimensions_size() != filter->shape().dimensions_size()) { + throw Exception( + "Convolution input and filter shapes must have the " + "same number of dimensions"); + } + pblczero::XlaShapeProto shape = input->shape(); + auto* out_dims = shape.mutable_dimensions(); + const auto& in_dims = input->shape().dimensions(); + const auto& filter_dims = filter->shape().dimensions(); + (*out_dims)[dn.output_batch_dimension()] = + in_dims[dn.input_batch_dimension()]; + (*out_dims)[dn.output_feature_dimension()] = + filter_dims[dn.kernel_output_feature_dimension()]; + for (size_t i = 0; i < dn.input_spatial_dimensions_size(); ++i) { + (*out_dims)[dn.output_spatial_dimensions(i)] = + in_dims[dn.input_spatial_dimensions(i)]; + } + auto* flow = MakeInstruction("convolution", shape, {input, filter}); + *flow->mutable_window() = window; + *flow->mutable_convolution_dimension_numbers() = dn; + return flow; +} + +HloFlow HloBuilder::Broadcast( + HloFlow input, const pblczero::XlaShapeProto& target_shape, + const std::vector& broadcast_dimensions) { + auto flow = MakeInstruction("broadcast", target_shape, {input}); + if (broadcast_dimensions.size() != input->shape().dimensions_size()) { + throw Exception("Broadcast must have the same size as the input shape"); + } + const auto& input_shape = input->shape(); + for (size_t i = 0; i < broadcast_dimensions.size(); ++i) { + auto dim = broadcast_dimensions[i]; + const auto& input_dim = input_shape.dimensions(i); + if (input_dim != 1 && input_dim != target_shape.dimensions(dim)) { + throw Exception( + "Broadcast dimension must be 1 or equal to the target shape " + "dimension"); + } + flow->add_dimensions(dim); + } + return flow; +} + +HloFlow HloBuilder::Add(HloFlow lhs, HloFlow rhs) { + return MakeElementwiseInstruction("add", lhs, rhs); +} + +HloFlow HloBuilder::Maximum(HloFlow lhs, HloFlow rhs) { + return MakeElementwiseInstruction("maximum", lhs, rhs); +} + +HloFlow HloBuilder::Reshape(HloFlow input, + const pblczero::XlaShapeProto& new_shape) { + if (input->shape().element_type() != new_shape.element_type()) { + throw Exception("Reshape must have the same element type"); + } + size_t old_elements = std::accumulate(input->shape().dimensions().begin(), + input->shape().dimensions().end(), 1, + std::multiplies()); + size_t new_elements = std::accumulate(new_shape.dimensions().begin(), + new_shape.dimensions().end(), 1, + std::multiplies()); + if (old_elements != new_elements) { + throw Exception("Reshape must have the same number of elements: " + + std::to_string(old_elements) + " vs " + + std::to_string(new_elements)); + } + return MakeInstruction("reshape", new_shape, {input}); +} + +HloFlow HloBuilder::Dot(HloFlow lhs, HloFlow rhs, + const pblczero::XlaDotDimensionNumbers& dn) { + pblczero::XlaShapeProto new_shape; + if (lhs->shape().element_type() != rhs->shape().element_type()) { + throw Exception("Dot operands must have the same element type"); + } + new_shape.set_element_type(lhs->shape().element_type()); + if (dn.lhs_batch_dimensions_size() != dn.rhs_batch_dimensions_size()) { + throw Exception("Dot batch dimensions must have the same size"); + } + for (size_t i = 0; i < dn.lhs_batch_dimensions_size(); ++i) { + auto lhs_dim = lhs->shape().dimensions(dn.lhs_batch_dimensions(i)); + auto rhs_dim = rhs->shape().dimensions(dn.rhs_batch_dimensions(i)); + if (lhs_dim != rhs_dim) { + throw Exception("Dot batch dimensions must have the same size"); + } + new_shape.add_dimensions(lhs_dim); + } + if (dn.lhs_contracting_dimensions_size() != + dn.rhs_contracting_dimensions_size()) { + throw Exception("Dot contracting dimensions must have the same size"); + } + for (size_t i = 0; i < dn.lhs_contracting_dimensions_size(); ++i) { + auto lhs_dim = lhs->shape().dimensions(dn.lhs_contracting_dimensions(i)); + auto rhs_dim = rhs->shape().dimensions(dn.rhs_contracting_dimensions(i)); + if (lhs_dim != rhs_dim) { + throw Exception("Dot contracting dimensions must have the same size"); + } + } + // Sorry, github copilot generated the code below (well, above too). Enjoy! + for (size_t i = 0; i < lhs->shape().dimensions_size(); ++i) { + if (std::find(dn.lhs_batch_dimensions().begin(), + dn.lhs_batch_dimensions().end(), + i) == dn.lhs_batch_dimensions().end() && + std::find(dn.lhs_contracting_dimensions().begin(), + dn.lhs_contracting_dimensions().end(), + i) == dn.lhs_contracting_dimensions().end()) { + new_shape.add_dimensions(lhs->shape().dimensions(i)); + } + } + for (size_t i = 0; i < rhs->shape().dimensions_size(); ++i) { + if (std::find(dn.rhs_batch_dimensions().begin(), + dn.rhs_batch_dimensions().end(), + i) == dn.rhs_batch_dimensions().end() && + std::find(dn.rhs_contracting_dimensions().begin(), + dn.rhs_contracting_dimensions().end(), + i) == dn.rhs_contracting_dimensions().end()) { + new_shape.add_dimensions(rhs->shape().dimensions(i)); + } + } + ResetXlaShapeProtoLayout(&new_shape); + auto flow = MakeInstruction("dot", new_shape, {lhs, rhs}); + *flow->mutable_dot_dimension_numbers() = dn; + return flow; +} + +HloFlow HloBuilder::Tanh(HloFlow input) { + return MakeInstruction("tanh", input->shape(), {input}); +} + +HloFlow HloBuilder::Tuple(const std::vector& elements) { + pblczero::XlaShapeProto shape; + shape.set_element_type(pblczero::XlaShapeProto::TUPLE); + for (const auto& element : elements) { + *shape.add_tuple_shapes() = element->shape(); + } + return MakeInstruction("tuple", shape, elements); +} + +namespace { +// Go over all "parameter" instructions of the computation and assign +// "parameter_number" field with increasing numbers. +// Normally it's not requiredm but in our case it's simpler. +// Outputs shapes and instruction names of parameters. +std::pair, std::vector> +AssignParameterIndices(const HloComputation& comp) { + std::vector parameter_shapes; + std::vector parameter_names; + size_t idx = 0; + for (const auto& instr : comp) { + if (instr->opcode() == "parameter") { + instr->set_parameter_number(idx++); + parameter_shapes.push_back(instr->shape()); + parameter_names.push_back(std::string(instr->name())); + } + } + return {parameter_shapes, parameter_names}; +} + +// Finalizes HloComputationProto (sets name, renumbers parameters, adds +// computation shape and root instruction). +pblczero::HloComputationProto MakeComputation(const HloComputation& comp, + std::string_view name, + size_t id) { + pblczero::HloComputationProto ret; + ret.set_id(id); + ret.set_name(name); + auto [shapes, names] = AssignParameterIndices(comp); + for (auto& instr : comp) *ret.add_instructions() = *instr; + *ret.mutable_program_shape()->mutable_parameters() = shapes; + *ret.mutable_program_shape()->mutable_parameter_names() = names; + *ret.mutable_program_shape()->mutable_result() = comp.back()->shape(); + ret.set_root_id(comp.back()->id()); + return ret; +} +} // namespace + +// Assigns unique names to all instructions in the module. +// In StableHLO instructions are allowed to have numeric names, but in XLA HLO +// they are not, so we use "i"+number. +void HloBuilder::AssignInstructionNames() { + // Every instruction in the module should have an unique name, numeric names + // are allowed. + size_t idx = 0; + for (auto& instr : entry_computation_) { + instr->set_name("i" + std::to_string(idx++)); + } + for (auto& [_, comp] : dependent_computations_) { + for (auto& instr : *comp.mutable_instructions()) { + instr.set_name("i" + std::to_string(idx++)); + } + } +} + +pblczero::HloModuleProto HloBuilder::Build(std::string_view name) { + AssignInstructionNames(); + pblczero::HloModuleProto module; + module.set_name(name); + module.set_entry_computation_name("main"); + module.set_entry_computation_id(0); + *module.add_computations() = MakeComputation(entry_computation_, "main", 0); + for (auto& [name, comp] : dependent_computations_) { + *module.add_computations() = comp; + } + *module.mutable_host_program_shape() = module.computations(0).program_shape(); + return module; +} + +void ResetXlaShapeProtoLayout(pblczero::XlaShapeProto* shape) { + shape->mutable_layout()->mutable_minor_to_major()->clear(); + shape->mutable_is_dynamic_dimension()->clear(); + + for (size_t i = 0; i < shape->dimensions_size(); ++i) { + shape->add_is_dynamic_dimension(false); + shape->mutable_layout()->add_minor_to_major(shape->dimensions_size() - i - + 1); + } +} + +} // namespace lczero \ No newline at end of file diff --git a/src/neural/xla/hlo_builder.h b/src/neural/xla/hlo_builder.h new file mode 100644 index 0000000000..b294408843 --- /dev/null +++ b/src/neural/xla/hlo_builder.h @@ -0,0 +1,107 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. +*/ + +#include +#include +#include + +#include "neural/xla/hlo.pb.h" +#include "utils/logging.h" + +namespace lczero { + +class HloContext; +class HloBuilder; + +using HloFlow = const pblczero::HloInstructionProto*; +using HloComputation = + std::vector>; + +// A builder class for constructing HloModuleProto. +class HloBuilder { + public: + // HLO operations. + HloFlow Parameter(const pblczero::XlaShapeProto& shape); + HloFlow Constant(const pblczero::XlaLiteralProto& literal); + HloFlow Convert(HloFlow input, const pblczero::XlaShapeProto::Type type); + HloFlow Convolution( + HloFlow input, HloFlow filter, const pblczero::XlaWindow& window, + const pblczero::XlaConvolutionDimensionNumbers& dimension_numbers); + HloFlow Broadcast(HloFlow input, const pblczero::XlaShapeProto& target_shape, + const std::vector& broadcast_dimensions); + HloFlow Add(HloFlow lhs, HloFlow rhs); + HloFlow Maximum(HloFlow lhs, HloFlow rhs); + HloFlow Reshape(HloFlow input, const pblczero::XlaShapeProto& new_shape); + HloFlow Dot(HloFlow lhs, HloFlow rhs, + const pblczero::XlaDotDimensionNumbers& dimension_numbers); + HloFlow Tanh(HloFlow input); + HloFlow Tuple(const std::vector& elements); + + // Build the HloModuleProto with a given name. + pblczero::HloModuleProto Build(std::string_view name); + + private: + pblczero::HloInstructionProto* MakeInstruction( + std::string_view opcode, const pblczero::XlaShapeProto& shape, + const std::vector operands); + pblczero::HloInstructionProto* MakeElementwiseInstruction( + std::string_view opcode, HloFlow lhs, HloFlow rhs); + void AssignInstructionNames(); + + HloComputation entry_computation_; + std::unordered_map + dependent_computations_; + pblczero::XlaOpMetadata metadata_; + friend class HloContext; +}; + +// A context class for annotating parts of the HLO computation with metadata, +// like original ONNX op, its name, and source file name and line. +// The class saves the current metadata in constructor and restores it in +// destructor, making it possible to use it in a scoped way. +class HloContext { + public: + HloContext(HloBuilder* builder) + : builder_(builder), saved_metadata_(builder->metadata_) {} + ~HloContext() { builder_->metadata_ = saved_metadata_; } + void SetOpType(std::string_view op_type) const { + builder_->metadata_.set_op_type(op_type); + } + void SetOpName(std::string_view op_name) const { + builder_->metadata_.set_op_name(op_name); + } + + private: + HloBuilder* builder_; + pblczero::XlaOpMetadata saved_metadata_; +}; + +// A helper function to reset a shape of a layout. Marks all dimensions as +// non-dynamic, and sets layout to major_to_minor. +void ResetXlaShapeProtoLayout(pblczero::XlaShapeProto* shape); + +} // namespace lczero \ No newline at end of file diff --git a/src/neural/xla/network_xla.cc b/src/neural/xla/network_xla.cc new file mode 100644 index 0000000000..37eec13ee9 --- /dev/null +++ b/src/neural/xla/network_xla.cc @@ -0,0 +1,305 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. +*/ + +#include + +#include "neural/factory.h" +#include "neural/network.h" +#include "neural/onnx/converter.h" +#include "neural/xla/onnx2hlo.h" +#include "neural/xla/xla_runner.h" +#include "utils/bititer.h" + +namespace lczero { +namespace { + +class Lc0InputTensor : public XlaTensor { + public: + Lc0InputTensor(size_t max_batch_size) + : max_batch_size_(max_batch_size), + // TODO replace with make_unique_for_overwrite() once C++20 is + // available. + data_(new float[GetTensorByteSizeForBatch(max_batch_size)]), + shape_{0, kInputPlanes, 8, 8} {} + + const std::vector& shape() const override { return shape_; } + const void* data() const override { return data_.get(); } + size_t size() const override { return GetTensorByteSizeForBatch(shape_[0]); } + size_t capacity() const override { + return GetTensorByteSizeForBatch(max_batch_size_); + } + pblczero::XlaShapeProto::Type type() const override { + return pblczero::XlaShapeProto::F32; + } + + // Adds a batch to the tensor and returns a pointer to the start of the its + // part in the buffer. Does NOT initialize the data with zeros. + float* AddBatch() { + assert(size_t(shape_[0]) < max_batch_size_); + auto ret = data_.get() + shape_[0] * GetTensorByteSizeForBatch(1); + ++shape_[0]; + return ret; + } + size_t GetBatchSize() const { return shape_[0]; } + + private: + static size_t GetTensorByteSizeForBatch(size_t batch_size) { + return kInputPlanes * 8 * 8 * batch_size * sizeof(float); + } + + const size_t max_batch_size_; + std::unique_ptr data_; + std::vector shape_; +}; + +class XlaNetwork; +class XlaComputation : public NetworkComputation { + public: + XlaComputation(const XlaNetwork* network); + void AddInput(InputPlanes&& input) override; + int GetBatchSize() const override; + void ComputeBlocking() override; + float GetQVal(int sample) const override; + float GetDVal(int sample) const override; + float GetPVal(int sample, int move_id) const override; + float GetMVal(int sample) const override; + + private: + const XlaNetwork* network_; + Lc0InputTensor input_tensor_; + std::vector> outputs_; +}; + +// Indices of various heads in the HLO output. +struct XlaNetworkOptions { + std::optional output_value_idx; + std::optional output_wdl_idx; + std::optional output_policy_idx; + std::optional output_mlh_idx; +}; + +class XlaNetwork : public Network { + public: + XlaNetwork(std::unique_ptr runner, + const XlaNetworkOptions& options, + const pblczero::NetworkFormat& format); + + const NetworkCapabilities& GetCapabilities() const override { + return capabilities_; + } + std::unique_ptr NewComputation() override { + return std::make_unique(this); + } + int GetMiniBatchSize() const override { + // 32 is the default prefetch size, subtract it so that backend doesn't + // crash. + // TODO make it better when we have a proper way to query the batch size. + return runner_->GetMaxBatchSize() - 32; + } + + private: + std::unique_ptr runner_; + XlaNetworkOptions options_; + NetworkCapabilities capabilities_; + + friend class XlaComputation; +}; + +XlaComputation::XlaComputation(const XlaNetwork* network) + : network_(network), input_tensor_(network->runner_->GetMaxBatchSize()) {} + +void XlaComputation::AddInput(InputPlanes&& input) { + float* ptr = input_tensor_.AddBatch(); + memset(ptr, 0, 8 * 8 * kInputPlanes * sizeof(float)); + for (const auto& plane : input) { + for (auto bit : IterateBits(plane.mask)) ptr[bit] = plane.value; + ptr += 8 * 8; + } +} + +float XlaComputation::GetQVal(int sample) const { + if (network_->options_.output_wdl_idx) { + const float* data = reinterpret_cast( + outputs_[*network_->options_.output_wdl_idx]->data()); + return data[sample * 3 + 0] - data[sample * 3 + 2]; + } else { + const float* data = reinterpret_cast( + outputs_[*network_->options_.output_value_idx]->data()); + return data[sample]; + } +} + +float XlaComputation::GetDVal(int sample) const { + if (network_->options_.output_wdl_idx) { + const float* data = reinterpret_cast( + outputs_[*network_->options_.output_wdl_idx]->data()); + return data[sample * 3 + 1]; + } + return 0.0f; +} + +float XlaComputation::GetPVal(int sample, int move_id) const { + const float* data = reinterpret_cast( + outputs_[*network_->options_.output_policy_idx]->data()); + return data[sample * 1858 + move_id]; +} + +float XlaComputation::GetMVal(int sample) const { + if (network_->options_.output_mlh_idx) { + const float* data = reinterpret_cast( + outputs_[*network_->options_.output_mlh_idx]->data()); + return data[sample]; + } + return 0.0f; +} + +int XlaComputation::GetBatchSize() const { + return input_tensor_.GetBatchSize(); +} + +void XlaComputation::ComputeBlocking() { + outputs_ = network_->runner_->ExecuteBlocking({&input_tensor_}); +} + +XlaNetwork::XlaNetwork(std::unique_ptr runner, + const XlaNetworkOptions& options, + const pblczero::NetworkFormat& format) + : runner_(std::move(runner)), + options_(options), + capabilities_{format.input(), format.moves_left()} {} + +// Converts ONNX model to HLO (for various batch sizes) and adds them to the +// XlaRunner. +XlaNetworkOptions FillXlaRunnerFromOnnx(const pblczero::OnnxModel& onnx_model, + XlaRunner* runner, + size_t max_batch_size, size_t steps) { + pblczero::ModelProto onnx; + onnx.ParseFromString(onnx_model.model()); + + std::unordered_map constant_to_parameter_idx; + std::unordered_map input_to_parameter_idx; + std::unordered_map output_to_parameter_idx; + + auto add_tensors = [](const std::vector& tensors, + std::unordered_map& map) { + for (const auto& tensor : tensors) { + auto iter = map.find(tensor.name); + if (iter == map.end()) { + map[tensor.name] = tensor.param_idx; + } else if (iter->second != tensor.param_idx) { + throw Exception("Inconsistent index for " + tensor.name); + } + } + }; + + for (size_t i = 0; i < steps; ++i) { + size_t batch_size = max_batch_size * (i + 1) / steps; + CERR << "Building HLO for batch size " << batch_size << "..."; + auto conversion = ConvertOnnxToHlo(onnx, batch_size, {}); + add_tensors(conversion.constants, constant_to_parameter_idx); + add_tensors(conversion.inputs, input_to_parameter_idx); + add_tensors(conversion.outputs, output_to_parameter_idx); + runner->AddModule(batch_size, conversion.hlo_module); + } + + std::vector> constants; + constants.resize(constant_to_parameter_idx.size() + + input_to_parameter_idx.size()); + for (const auto& initializer : onnx.graph().initializer()) { + auto iter = constant_to_parameter_idx.find(std::string(initializer.name())); + if (iter == constant_to_parameter_idx.end()) continue; + auto idx = iter->second; + assert(idx < constants.size()); + constants[idx] = OnnxTensorToXlaTensor(initializer); + } + + CERR << "Transferring constants..."; + runner->SetFrozenInputs(std::move(constants)); + CERR << "Done."; + + XlaNetworkOptions options; + if (input_to_parameter_idx.size() != 1 || + input_to_parameter_idx.begin()->first != onnx_model.input_planes()) { + throw Exception("Expected a single input named " + + std::string(onnx_model.input_planes())); + } + if (onnx_model.has_output_value()) { + options.output_value_idx = + output_to_parameter_idx.at(std::string(onnx_model.output_value())); + } + if (onnx_model.has_output_wdl()) { + options.output_wdl_idx = + output_to_parameter_idx.at(std::string(onnx_model.output_wdl())); + } + if (onnx_model.has_output_policy()) { + options.output_policy_idx = + output_to_parameter_idx.at(std::string(onnx_model.output_policy())); + } + if (onnx_model.has_output_mlh()) { + options.output_mlh_idx = + output_to_parameter_idx.at(std::string(onnx_model.output_mlh())); + } + return options; +} + +// Makes an XLA network. First converts the weights to ONNX, and then calls +// FillXlaRunnerFromOnnx to convert them further to HLO and them compile them. +std::unique_ptr MakeXlaNetwork(const std::optional& w, + const OptionsDict& opts) { + if (!w) throw Exception("The XLA backend requires a network file."); + int device = opts.GetOrDefault("device", 0); + // Note: if the plugin_path does NOT contain a slash, it's looked up in the + // LD_LIBRARY_PATH (and a few other system defined places). If it does contain + // a slash, it's looked up at the exact relative or absolute path. + auto runner = std::make_unique( + opts.GetOrDefault("plugin_path", + "./pjrt_c_api_gpu_plugin.so") + .c_str(), + device); + int max_batch_size = opts.GetOrDefault("max_batch", 739); + int steps = opts.GetOrDefault("steps", 13); + + XlaNetworkOptions options; + if (w->has_onnx_model()) { + options = FillXlaRunnerFromOnnx(w->onnx_model(), runner.get(), + max_batch_size, steps); + } else { + CERR << "Converting weights to ONNX first."; + WeightsToOnnxConverterOptions onnx_converter_options; + auto converted = ConvertWeightsToOnnx(*w, onnx_converter_options); + options = FillXlaRunnerFromOnnx(converted.onnx_model(), runner.get(), + max_batch_size, steps); + } + + return std::make_unique(std::move(runner), options, + w->format().network_format()); +} + +REGISTER_NETWORK("xla", MakeXlaNetwork, -34) + +} // namespace +} // namespace lczero diff --git a/src/neural/xla/onnx2hlo.cc b/src/neural/xla/onnx2hlo.cc new file mode 100644 index 0000000000..b41906e71d --- /dev/null +++ b/src/neural/xla/onnx2hlo.cc @@ -0,0 +1,532 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. +*/ + +#include "neural/xla/onnx2hlo.h" + +#include + +#include "neural/onnx/onnx.pb.h" +#include "neural/xla/hlo.pb.h" +#include "neural/xla/hlo_builder.h" +#include "neural/xla/print_hlo.h" +#include "utils/exception.h" + +namespace lczero { +namespace { + +pblczero::XlaShapeProto::Type OnnxTypeToXlaType( + const pblczero::TensorProto::DataType& type) { + switch (type) { + case pblczero::TensorProto::FLOAT: + return pblczero::XlaShapeProto::F32; + case pblczero::TensorProto::UINT8: + return pblczero::XlaShapeProto::U8; + case pblczero::TensorProto::INT8: + return pblczero::XlaShapeProto::S8; + case pblczero::TensorProto::UINT16: + return pblczero::XlaShapeProto::U16; + case pblczero::TensorProto::INT16: + return pblczero::XlaShapeProto::S16; + case pblczero::TensorProto::INT32: + return pblczero::XlaShapeProto::S32; + case pblczero::TensorProto::INT64: + return pblczero::XlaShapeProto::S64; + case pblczero::TensorProto::BOOL: + return pblczero::XlaShapeProto::PRED; + case pblczero::TensorProto::FLOAT16: + return pblczero::XlaShapeProto::F16; + case pblczero::TensorProto::DOUBLE: + return pblczero::XlaShapeProto::F64; + case pblczero::TensorProto::UINT32: + return pblczero::XlaShapeProto::U32; + case pblczero::TensorProto::UINT64: + return pblczero::XlaShapeProto::U64; + case pblczero::TensorProto::COMPLEX64: + return pblczero::XlaShapeProto::C64; + case pblczero::TensorProto::COMPLEX128: + return pblczero::XlaShapeProto::C128; + case pblczero::TensorProto::BFLOAT16: + return pblczero::XlaShapeProto::BF16; + case pblczero::TensorProto::STRING: + default: + throw Exception("Unsupported ONNX type " + + pblczero::TensorProto::DataType_Name(type)); + } +} + +// Converts an ONNX shape to an XLA shape, replacing the batch dimension with +// the provided batch size. +pblczero::XlaShapeProto OnnxShapeToXlaShape(const pblczero::TypeProto& type, + std::optional batch_size) { + pblczero::XlaShapeProto shape; + shape.set_element_type(OnnxTypeToXlaType(type.tensor_type().elem_type())); + for (const auto& dim : type.tensor_type().shape().dim()) { + if (dim.has_dim_value()) { + shape.add_dimensions(dim.dim_value()); + continue; + } + if (dim.dim_param() == "batch") { + if (batch_size.has_value()) { + shape.add_dimensions(batch_size.value()); + continue; + } + throw Exception("Batch size not provided"); + } + throw Exception("Unsupported dimension type " + type.OutputAsJson()); + } + ResetXlaShapeProtoLayout(&shape); + return shape; +} + +// Type is not a field of the ONNX tensor, so this function extracts the shape +// and converts it (discarding the data). +pblczero::XlaShapeProto OnnxTensorToXlaShape( + const pblczero::TensorProto& tensor) { + pblczero::TypeProto type; + type.mutable_tensor_type()->set_elem_type(tensor.data_type()); + for (const auto& dim : tensor.dims()) { + type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); + } + return OnnxShapeToXlaShape(type, std::nullopt); +} + +// Converts an ONNX tensor to an XLA literal (which is a shape and a data). +pblczero::XlaLiteralProto OnnxTensorToXlaLiteral( + const pblczero::TensorProto& tensor) { + pblczero::XlaLiteralProto literal; + *literal.mutable_shape() = OnnxTensorToXlaShape(tensor); + + auto convert = [&](std::string_view src, /*std::vector*/ auto* dst) { + using value_type = + typename std::remove_pointer::type::value_type; + dst->assign(reinterpret_cast(src.data()), + reinterpret_cast(src.data() + src.size())); + }; + + switch (tensor.data_type()) { + case pblczero::TensorProto::FLOAT: + convert(tensor.raw_data(), literal.mutable_f32s()); + break; + case pblczero::TensorProto::INT64: + convert(tensor.raw_data(), literal.mutable_s64s()); + break; + default: + throw Exception("Cannot convert ONNX tensor to XLA literal for type " + + pblczero::XlaShapeProto::Type_Name( + OnnxTypeToXlaType(tensor.data_type()))); + } + return literal; +} + +class Onnx2HloConverter { + public: + Onnx2HloConverter(const Onnx2HloOptions& options) : options_(options) { + onnx_op_to_builder_["Add"] = &Onnx2HloConverter::OpAdd; + onnx_op_to_builder_["Conv"] = &Onnx2HloConverter::OpConv; + onnx_op_to_builder_["MatMul"] = &Onnx2HloConverter::OpMatMul; + onnx_op_to_builder_["Relu"] = &Onnx2HloConverter::OpRelu; + onnx_op_to_builder_["Reshape"] = &Onnx2HloConverter::OpReshape; + onnx_op_to_builder_["Tanh"] = &Onnx2HloConverter::OpTanh; + } + + Onnx2HloResult Convert(const pblczero::ModelProto& onnx_model, + size_t minibatch_size) { + batch_size_ = minibatch_size; + // Populate the set of ONNX initializers (constants), but not emit them for + // now. They are emitted lazily so that they appear close to the first use. + BuildInitializerMapping(onnx_model); + // Convert ONNX inputs to HLO parameters. + BuildInputs(onnx_model.graph().input()); + BuildGraph(onnx_model.graph()); + Onnx2HloResult result; + // Convert ONNX outputs to HLO result. + result.outputs = BuildOutputs(onnx_model.graph().output()); + result.hlo_module = builder_.Build("onnx_model"); + for (size_t i = 0; i < params_.size(); ++i) { + const auto& param = params_[i]; + auto& dst = param.is_constant ? result.constants : result.inputs; + dst.push_back({i, param.name, param.flow->shape()}); + } + // PrettyPrintHlo(result.hlo_module, {}, std::cout); + return result; + } + + private: + std::vector BuildOutputs( + const std::vector& graph_output) { + // Gathers outputs into the root tuple, optionally converting their type if + // I/O type is different from the instruction output. + std::vector result; + std::vector outputs; + for (size_t i = 0; i < graph_output.size(); ++i) { + const auto& output = graph_output[i]; + auto flow = GetFlowByName(std::string(output.name())); + if (flow->shape().element_type() != options_.io_type) { + auto ctx = HloContext(&builder_); + ctx.SetOpType("output"); + ctx.SetOpName(output.name()); + flow = builder_.Convert(flow, options_.io_type); + } + result.push_back({i, std::string(output.name()), flow->shape()}); + outputs.push_back(flow); + } + builder_.Tuple(outputs); + return result; + } + + void BuildInitializerMapping(const pblczero::ModelProto& onnx_model) { + for (const auto& tensor : onnx_model.graph().initializer()) { + initializers_[std::string(tensor.name())] = &tensor; + } + } + + // Checks that the ONNX node doesn't have any unknown attributes. + void CheckKnownAttributes( + const pblczero::NodeProto& node, + const std::initializer_list attributes) { + for (const auto& attribute : node.attribute()) { + if (std::find(attributes.begin(), attributes.end(), attribute.name()) == + attributes.end()) { + throw Exception("Unknown attribute " + std::string(attribute.name())); + } + } + } + + // Fetches an HloFlow by name. If the name is not in the map, check whether + // there is an initializer for it, and either create a constant or a parameter + // depending on its size. + HloFlow GetFlowByName(const std::string& name) { + auto iter = onnx_name_to_hlo_flow_.find(name); + if (iter != onnx_name_to_hlo_flow_.end()) return iter->second; + + auto iter2 = initializers_.find(name); + if (iter2 == initializers_.end()) { + throw Exception("Unknown input " + name); + } + auto ctx = HloContext(&builder_); + ctx.SetOpType("initializer"); + ctx.SetOpName(name); + + HloFlow flow = nullptr; + if (iter2->second->raw_data().size() <= options_.max_inline_constant_size) { + flow = builder_.Constant(OnnxTensorToXlaLiteral(*iter2->second)); + } else { + const auto shape = OnnxTensorToXlaShape(*iter2->second); + flow = MakeParameter(name, shape, true); + } + onnx_name_to_hlo_flow_[name] = flow; + return flow; + } + + // A helper function to fetch an input of ONNX node by index. + HloFlow GetInput(const pblczero::NodeProto& node, size_t idx, + bool optional = false) { + if (idx >= node.input_size()) { + if (optional) return nullptr; + throw Exception("Input " + std::to_string(idx) + " not set"); + } + return GetFlowByName(std::string(node.input(idx))); + } + + // A helper function to fetch an attribute of ONNX node by name. + const pblczero::AttributeProto* GetAttribute(const pblczero::NodeProto& node, + std::string_view name, + bool optional = false) { + for (const auto& attribute : node.attribute()) { + if (attribute.name() == name) return &attribute; + } + if (optional) return nullptr; + throw Exception("Attribute " + std::string(name) + " not set"); + } + + ///////////////////////////////////////////////////////////////////////////// + // ONNX operations + ///////////////////////////////////////////////////////////////////////////// + + std::vector OpConv(const pblczero::NodeProto& node) { + CheckKnownAttributes(node, {"pads", "kernel_shape"}); + auto* input = GetInput(node, 0); + auto* kernel = GetInput(node, 1); + auto* bias = GetInput(node, 2, true); + + pblczero::XlaConvolutionDimensionNumbers dn; + dn.set_input_batch_dimension(0); + dn.set_input_feature_dimension(1); + dn.set_kernel_input_feature_dimension(1); + dn.set_kernel_output_feature_dimension(0); + dn.set_output_batch_dimension(0); + dn.set_output_feature_dimension(1); + const size_t num_dims = input->shape().dimensions_size() - 2; + for (size_t i = 0; i < num_dims; ++i) { + dn.add_input_spatial_dimensions(i + 2); + dn.add_kernel_spatial_dimensions(i + 2); + dn.add_output_spatial_dimensions(i + 2); + } + + const auto* pads = GetAttribute(node, "pads"); + const auto* kernel_shape = GetAttribute(node, "kernel_shape"); + if (!pads || pads->ints_size() != 2 * num_dims) { + throw Exception("'pads' attribute not set or wrong size"); + } + if (!kernel_shape || kernel_shape->ints_size() != num_dims) { + throw Exception("'kernel_shape' attribute not set or wrong size"); + } + pblczero::XlaWindow window; + for (size_t i = 0; i < input->shape().dimensions_size() - 2; ++i) { + auto* dim = window.add_dimensions(); + dim->set_size(kernel_shape->ints(i)); + dim->set_stride(1); + dim->set_padding_low(pads->ints(i)); + dim->set_padding_high(pads->ints(i + num_dims)); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + + auto* conv = builder_.Convolution(input, kernel, window, dn); + + if (!bias) return {conv}; + auto* flow = builder_.Broadcast(bias, conv->shape(), {1}); + return {builder_.Add(conv, flow)}; + } + + std::vector OpRelu(const pblczero::NodeProto& node) { + CheckKnownAttributes(node, {}); + auto* input = GetInput(node, 0); + auto* zero = MakeScalar(0, input->shape().element_type()); + zero = builder_.Broadcast(zero, input->shape(), {}); + return {builder_.Maximum(input, zero)}; + } + + std::vector OpTanh(const pblczero::NodeProto& node) { + CheckKnownAttributes(node, {}); + auto* input = GetInput(node, 0); + return {builder_.Tanh(input)}; + } + + std::vector OpAdd(const pblczero::NodeProto& node) { + CheckKnownAttributes(node, {}); + auto* lhs = GetInput(node, 0); + auto* rhs = GetInput(node, 1); + std::tie(lhs, rhs) = EqualizeShape(lhs, rhs); + return {builder_.Add(lhs, rhs)}; + } + + std::vector OpReshape(const pblczero::NodeProto& node) { + CheckKnownAttributes(node, {}); + auto* input = GetInput(node, 0); + if (node.input_size() < 2) { + throw Exception("Reshape requires a shape input"); + } + auto dims_tensor = initializers_.find(std::string(node.input(1))); + if (dims_tensor == initializers_.end()) { + throw Exception("Reshape only supports constant shape"); + } + auto new_dims = OnnxTensorToXlaLiteral(*dims_tensor->second).s64s(); + pblczero::XlaShapeProto new_shape; + new_shape.set_element_type(input->shape().element_type()); + for (size_t i = 0; i < new_dims.size(); ++i) { + auto dim = new_dims[i]; + if (dim == -1) dim = batch_size_; + if (dim == 0) { + if (new_dims.size() != input->shape().dimensions_size()) { + throw Exception("Reshape cannot infer shape when rank changes"); + } + dim = input->shape().dimensions(i); + } + new_shape.add_dimensions(dim); + } + ResetXlaShapeProtoLayout(&new_shape); + return {builder_.Reshape(input, new_shape)}; + } + + std::vector OpMatMul(const pblczero::NodeProto& node) { + CheckKnownAttributes(node, {}); + auto* lhs = GetInput(node, 0); + auto* rhs = GetInput(node, 1); + if (lhs->shape().dimensions_size() != 2 || + rhs->shape().dimensions_size() != 2) { + throw Exception("MatMul only implemented for 2D inputs so far"); + } + pblczero::XlaDotDimensionNumbers dn; + dn.add_lhs_contracting_dimensions(1); + dn.add_rhs_contracting_dimensions(0); + return {builder_.Dot(lhs, rhs, dn)}; + } + + ///////////////////////////////////////////////////////////////////////////// + + // Makes a scalar constant (usually 0 or 1) of a given type. + template + HloFlow MakeScalar(T value, pblczero::XlaShapeProto::Type type) { + pblczero::XlaLiteralProto literal; + literal.mutable_shape()->set_element_type(type); + literal.mutable_shape()->mutable_layout(); + switch (type) { + case pblczero::XlaShapeProto::F32: + literal.add_f32s(value); + break; + default: + throw Exception("Unsupported type for zero constant"); + } + return builder_.Constant(literal); + } + + // Take two inputs and optionally performs numpy-style broadcasting to make + // them equal shape. + std::pair EqualizeShape(HloFlow lhs, HloFlow rhs) { + const auto& lhs_dims = lhs->shape().dimensions(); + const auto& rhs_dims = rhs->shape().dimensions(); + + const size_t num_dims = std::max(lhs_dims.size(), rhs_dims.size()); + std::vector output_dims(num_dims); + std::vector lhs_broadcast_dims; + std::vector rhs_broadcast_dims; + bool lhs_broadcast = lhs_dims.size() < num_dims; + bool rhs_broadcast = rhs_dims.size() < num_dims; + + for (size_t i = 0; i < num_dims; ++i) { + int lhs_idx = i + lhs_dims.size() - num_dims; + int rhs_idx = i + rhs_dims.size() - num_dims; + const auto lhs_dim = (lhs_idx < 0) ? 1 : lhs_dims[lhs_idx]; + const auto rhs_dim = (rhs_idx < 0) ? 1 : rhs_dims[rhs_idx]; + if (lhs_dim != rhs_dim) { + if (lhs_dim != 1 && rhs_dim != 1) { + throw Exception("Incompatible shapes for broadcast"); + } + if (lhs_dim == 1) lhs_broadcast = true; + if (rhs_dim == 1) rhs_broadcast = true; + } + if (lhs_idx >= 0) lhs_broadcast_dims.push_back(i); + if (rhs_idx >= 0) rhs_broadcast_dims.push_back(i); + } + + if (lhs_broadcast) { + lhs = builder_.Broadcast(lhs, rhs->shape(), lhs_broadcast_dims); + } + if (rhs_broadcast) { + rhs = builder_.Broadcast(rhs, lhs->shape(), rhs_broadcast_dims); + } + return {lhs, rhs}; + } + + // Convert ONNX inputs to HLO parameters. + void BuildInputs(const std::vector& inputs) { + for (const auto& input : inputs) { + auto ctx = HloContext(&builder_); + ctx.SetOpType("input"); + ctx.SetOpName(input.name()); + auto out_shape = OnnxShapeToXlaShape(input.type(), batch_size_); + auto in_shape = out_shape; + in_shape.set_element_type(options_.io_type); + const auto* flow = + MakeParameter(std::string(input.name()), in_shape, false); + flow = builder_.Convert(flow, out_shape.element_type()); + onnx_name_to_hlo_flow_[std::string(input.name())] = flow; + } + } + + // Makes a parameter instruction (for inputs or large constants). + HloFlow MakeParameter(const std::string& name, + const pblczero::XlaShapeProto& shape, + bool is_constant) { + auto* res = builder_.Parameter(shape); + params_.push_back({name, res, is_constant}); + return res; + } + + void BuildGraph(const pblczero::GraphProto& graph) { + for (const auto& node : graph.node()) { + // Set up the context so that nodes have metadata from the original ONNX. + auto ctx = HloContext(&builder_); + ctx.SetOpType(node.op_type()); + ctx.SetOpName(node.name()); + DispatchNode(node); + } + } + + // Calls the correct function to handle the ONNX node, and stores output in + // the map. + void DispatchNode(const pblczero::NodeProto& node) { + auto iter = onnx_op_to_builder_.find(std::string(node.op_type())); + if (iter == onnx_op_to_builder_.end()) { + throw Exception("Unsupported ONNX op[" + std::string(node.op_type()) + + "] name=[" + std::string(node.name()) + "]"); + } + try { + auto outputs = (this->*iter->second)(node); + if (outputs.size() != node.output_size()) { + throw Exception("Node produced wrong number of outputs"); + } + for (size_t i = 0; i < outputs.size(); ++i) { + onnx_name_to_hlo_flow_[std::string(node.output(i))] = outputs[i]; + } + } catch (Exception& e) { + throw Exception("Error in ONNX op[" + std::string(node.op_type()) + + "] name=[" + std::string(node.name()) + "]: " + e.what()); + } + } + + std::unordered_map onnx_name_to_hlo_flow_; + std::unordered_map (Onnx2HloConverter::*)( + const pblczero::NodeProto&)> + onnx_op_to_builder_; + std::unordered_map initializers_; + HloBuilder builder_; + size_t batch_size_ = 0; + Onnx2HloOptions options_; + struct Param { + std::string name; + HloFlow flow; + bool is_constant; + }; + std::vector params_; +}; + +} // namespace + +Onnx2HloResult ConvertOnnxToHlo(const pblczero::ModelProto& onnx_model, + size_t minibatch_size, + const Onnx2HloOptions& options) { + Onnx2HloConverter converter(options); + return converter.Convert(onnx_model, minibatch_size); +} + +std::unique_ptr OnnxTensorToXlaTensor( + const pblczero::TensorProto& onnx_tensor) { + switch (onnx_tensor.data_type()) { + case pblczero::TensorProto::FLOAT: + return std::make_unique(onnx_tensor.dims(), + onnx_tensor.raw_data(), + pblczero::XlaShapeProto::F32); + default: + throw Exception( + "Unsupported ONNX tensor type for buffer conversion " + + pblczero::TensorProto::DataType_Name(onnx_tensor.data_type())); + } +} + +} // namespace lczero \ No newline at end of file diff --git a/src/neural/xla/onnx2hlo.h b/src/neural/xla/onnx2hlo.h new file mode 100644 index 0000000000..284f08d1bc --- /dev/null +++ b/src/neural/xla/onnx2hlo.h @@ -0,0 +1,73 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. +*/ + +#pragma once + +#include +#include + +#include "neural/onnx/onnx.pb.h" +#include "neural/xla/hlo.pb.h" +#include "neural/xla/xla_runner.h" + +namespace lczero { + +struct Onnx2HloOptions { + // Constants larger that this size in bytes will be passed as parameters + // instead. This allows them to be shared between different modules. + size_t max_inline_constant_size = 1024; + // The types of input/output tensors (does not affect constants passed as + // parameters). + pblczero::XlaShapeProto::Type io_type = pblczero::XlaShapeProto::F32; +}; + +struct Onnx2HloResult { + struct NamedTensor { + // Index of the tensor in the input or output tuple. + size_t param_idx; + // Name of the tensor from the ONNX model. + std::string name; + pblczero::XlaShapeProto shape; + }; + // Constants that are passed as inputs to the module. + std::vector constants; + std::vector inputs; + std::vector outputs; + pblczero::HloModuleProto hlo_module; +}; + +// Converts an ONNX model to an HLO module. +Onnx2HloResult ConvertOnnxToHlo(const pblczero::ModelProto& onnx_model, + size_t minibatch_size, + const Onnx2HloOptions& options); + +// Converts an ONNX tensor to an XLA tensor (thanks GitHub Copilot for the +// comment suggestion). +std::unique_ptr OnnxTensorToXlaTensor( + const pblczero::TensorProto& onnx_tensor); + +} // namespace lczero diff --git a/src/neural/xla/pjrt.cc b/src/neural/xla/pjrt.cc new file mode 100644 index 0000000000..745a2d40f3 --- /dev/null +++ b/src/neural/xla/pjrt.cc @@ -0,0 +1,403 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. +*/ + +#include "pjrt.h" + +#include + +#include +#include +#include +#include + +#include "pjrt_c_api.h" +#include "utils/logging.h" + +namespace lczero { + +namespace { +static std::string value_to_string(const std::string& value) { return value; } +static std::string value_to_string(int64_t value) { + return std::to_string(value); +} +static std::string value_to_string(const std::vector& value) { + std::string result; + for (auto v : value) { + if (!result.empty()) result += ", "; + result += std::to_string(v); + } + return result; +} +static std::string value_to_string(float value) { + return std::to_string(value); +} +static std::string value_to_string(bool value) { + return value ? "true" : "false"; +} + +template +T MakeStruct() { + T t; + memset(&t, 0, sizeof(t)); + t.struct_size = sizeof(t); + return t; +} + +PJRT_Error_Code GetErrorCode(const PJRT_Api* api, PJRT_Error* error) { + auto args = MakeStruct(); + args.error = error; + api->PJRT_Error_GetCode(&args); + return args.code; +} + +} // namespace + +std::string PjrtKeyValue::value_as_string() const { + return std::visit([&](const auto& arg) { return value_to_string(arg); }, + value_); +} + +PjrtKeyValue MakeKeyValue(const PJRT_NamedValue* kv) { + PjrtKeyValue result; + result.set_key({kv->name, kv->name_size}); + switch (kv->type) { + case PJRT_NamedValue_kString: + result.set_value(std::string(kv->string_value, kv->value_size)); + break; + case PJRT_NamedValue_kInt64: + result.set_value(kv->int64_value); + break; + case PJRT_NamedValue_kInt64List: + result.set_value(std::vector( + kv->int64_array_value, kv->int64_array_value + kv->value_size)); + break; + case PJRT_NamedValue_kFloat: + result.set_value(kv->float_value); + break; + case PJRT_NamedValue_kBool: + result.set_value(kv->bool_value); + break; + } + return result; +} + +std::string PjrtCommon::GetErrorMessage(PJRT_Error* error) const { + auto args = MakeStruct(); + args.error = error; + api_->PJRT_Error_Message(&args); + return std::string(args.message, args.message_size); +} +void PjrtCommon::DestroyErrorMessage(PJRT_Error* error) const { + assert(error); + auto args = MakeStruct(); + args.error = error; + api_->PJRT_Error_Destroy(&args); +} + +void PjrtCommon::CheckError(PJRT_Error* error) const { + if (!error) return; + PjrtException exception(static_cast(GetErrorCode(api_, error)), + GetErrorMessage(error)); + DestroyErrorMessage(error); + throw exception; +} + +PjrtExecutable::PjrtExecutable(const PJRT_Api* api, + PJRT_LoadedExecutable* executable) + : PjrtCommon(api), executable_(executable) { + auto args = MakeStruct(); + args.loaded_executable = executable_; + CheckError(api_->PJRT_LoadedExecutable_GetExecutable(&args)); + + auto args2 = MakeStruct(); + args2.executable = args.executable; + CheckError(api_->PJRT_Executable_NumOutputs(&args2)); + num_outputs_ = args2.num_outputs; +} + +PjrtExecutable::~PjrtExecutable() { + auto args = MakeStruct(); + args.executable = executable_; + CheckError(api_->PJRT_LoadedExecutable_Destroy(&args)); +} + +size_t PjrtExecutable::GetNumOutputs() const { return num_outputs_; } + +std::vector> PjrtExecutable::ExecuteBlocking( + const std::vector& inputs) { + auto options = MakeStruct(); + options.num_non_donatable_input_indices = inputs.size(); + std::vector non_donatable_indices(inputs.size()); + // TODO the buffer 0 is actually donatable. + std::iota(non_donatable_indices.begin(), non_donatable_indices.end(), 0); + options.non_donatable_input_indices = non_donatable_indices.data(); + + auto args = MakeStruct(); + args.executable = executable_; + args.options = &options; + args.num_devices = 1; + std::vector buffers(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) buffers[i] = inputs[i]->buffer_; + PJRT_Buffer* const* buffers_ptr = buffers.data(); + args.num_args = inputs.size(); + args.argument_lists = &buffers_ptr; + + std::vector outputs(num_outputs_); + PJRT_Buffer** outputs_ptr = outputs.data(); + PJRT_Event* event_ptr; + args.output_lists = &outputs_ptr; + args.device_complete_events = &event_ptr; + CheckError(api_->PJRT_LoadedExecutable_Execute(&args)); + + PjrtEvent event(api_, event_ptr); + event.Await(); + + std::vector> output_buffers; + output_buffers.reserve(num_outputs_); + for (size_t i = 0; i < num_outputs_; ++i) { + output_buffers.push_back( + std::make_unique(api_, outputs[i])); + } + return output_buffers; +} + +PjrtDevice::PjrtDevice(const PJRT_Api* api, PJRT_Device* device) + : PjrtCommon(api), device_(device) { + auto args = MakeStruct(); + args.device = device_; + CheckError(api_->PJRT_Device_GetDescription(&args)); + description_ = args.device_description; +} + +std::string PjrtDevice::ToString() const { + auto args = MakeStruct(); + args.device_description = description_; + CheckError(api_->PJRT_DeviceDescription_ToString(&args)); + return {args.to_string, args.to_string_size}; +} + +PjrtClient::PjrtClient(const PJRT_Api* api, PJRT_Client* client) + : PjrtCommon(api), client_(client) {} + +PjrtClient::~PjrtClient() { + auto args = MakeStruct(); + args.client = client_; + CheckError(api_->PJRT_Client_Destroy(&args)); +} + +std::unique_ptr PjrtClient::CompileHlo( + std::string_view hlo, std::string_view config) { + constexpr std::string_view kFormat = "hlo"; + auto program = MakeStruct(); + program.code = const_cast(hlo.data()); + program.code_size = hlo.size(); + program.format = kFormat.data(); + program.format_size = kFormat.size(); + + auto args = MakeStruct(); + args.client = client_; + args.program = &program; + args.compile_options = const_cast(config.data()); + args.compile_options_size = config.size(); + CheckError(api_->PJRT_Client_Compile(&args)); + return std::make_unique(api_, args.executable); +} + +std::vector> PjrtClient::GetDevices() { + auto args = MakeStruct(); + args.client = client_; + CheckError(api_->PJRT_Client_Devices(&args)); + std::vector> result; + result.reserve(args.num_devices); + for (size_t i = 0; i < args.num_devices; ++i) { + result.push_back(std::make_unique(api_, args.devices[i])); + } + return result; +} + +PjrtEvent::PjrtEvent(const PJRT_Api* api, PJRT_Event* event) + : PjrtCommon(api), event_(event) {} + +PjrtEvent::~PjrtEvent() { + auto args = MakeStruct(); + args.event = event_; + CheckError(api_->PJRT_Event_Destroy(&args)); +} + +void PjrtEvent::Await() { + auto args = MakeStruct(); + args.event = event_; + CheckError(api_->PJRT_Event_Await(&args)); +} + +PjrtDeviceBuffer::PjrtDeviceBuffer(const PJRT_Api* api, PJRT_Buffer* buffer) + : PjrtCommon(api), buffer_(buffer) {} + +PjrtDeviceBuffer::~PjrtDeviceBuffer() { + auto args = MakeStruct(); + args.buffer = buffer_; + CheckError(api_->PJRT_Buffer_Destroy(&args)); +} + +size_t PjrtDeviceBuffer::GetSize() const { + auto args = MakeStruct(); + args.src = buffer_; + CheckError(api_->PJRT_Buffer_ToHostBuffer(&args)); + return args.dst_size; +} + +PjrtType PjrtDeviceBuffer::GetType() const { + auto args = MakeStruct(); + args.buffer = buffer_; + CheckError(api_->PJRT_Buffer_ElementType(&args)); + return static_cast(args.type); +} + +std::vector PjrtDeviceBuffer::GetDimensions() const { + auto args = MakeStruct(); + args.buffer = buffer_; + CheckError(api_->PJRT_Buffer_Dimensions(&args)); + return {args.dims, args.dims + args.num_dims}; +} + +std::unique_ptr PjrtDeviceBuffer::DeviceToHost(void* dst, + size_t size) { + auto args = MakeStruct(); + args.src = buffer_; + args.dst = dst; + args.dst_size = size; + CheckError(api_->PJRT_Buffer_ToHostBuffer(&args)); + return std::make_unique(api_, args.event); +} + +PjrtHostToDeviceTransfer::PjrtHostToDeviceTransfer( + const PJRT_Api* api, PJRT_Buffer* buffer, std::unique_ptr event) + : PjrtCommon(api), buffer_(buffer), event_(std::move(event)) {} + +void PjrtHostToDeviceTransfer::Await() { event_->Await(); } + +std::unique_ptr +PjrtHostToDeviceTransfer::AwaitAndReleaseBuffer() { + if (!buffer_) { + throw PjrtException(PjrtErrorCode::INVALID_ARGUMENT, + "Buffer already released"); + } + Await(); + auto res = std::make_unique(api_, buffer_); + buffer_ = nullptr; + return res; +} + +PjrtHostToDeviceTransfer::~PjrtHostToDeviceTransfer() { + Await(); + if (buffer_) { + auto args = MakeStruct(); + args.buffer = buffer_; + CheckError(api_->PJRT_Buffer_Destroy(&args)); + } +} + +std::unique_ptr PjrtClient::HostToDevice( + std::string_view buffer, PjrtType type, const std::vector& dims, + const PjrtDevice* device) { + auto args = MakeStruct(); + args.client = client_; + args.data = buffer.data(); + args.type = static_cast(type); + args.dims = dims.data(); + args.num_dims = dims.size(); + args.host_buffer_semantics = + PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes; + args.device = device->device_; + CheckError(api_->PJRT_Client_BufferFromHostBuffer(&args)); + auto event = std::make_unique(api_, args.done_with_host_buffer); + return std::make_unique(api_, args.buffer, + std::move(event)); +} + +Pjrt::Pjrt(const char* library_path) : PjrtCommon(nullptr) { + // TODO factor out the dlopen/dlsym code into a separate function, and + // implement for other OSes. + void* handle = dlopen(library_path, RTLD_LAZY); + if (!handle) { + throw PjrtException(PjrtErrorCode::INVALID_ARGUMENT, + "Unable to load PJRT library " + + std::string(library_path) + ": " + dlerror()); + } + typedef const PJRT_Api* (*PjrtApiFunc)(); + auto func = reinterpret_cast(dlsym(handle, "GetPjrtApi")); + if (!func) { + throw PjrtException(PjrtErrorCode::INVALID_ARGUMENT, + "Unable to find GetPjrtApi() in PJRT library " + + std::string(library_path) + ": " + dlerror()); + } + api_ = func(); + if (!api_) { + throw PjrtException(PjrtErrorCode::INVALID_ARGUMENT, + "GetPjrtApi() returned nullptr in PJRT library " + + std::string(library_path)); + } + auto [major, minor] = ApiVersion(); + if (major != PJRT_API_MAJOR || minor < PJRT_API_MINOR) { + throw PjrtException( + PjrtErrorCode::INVALID_ARGUMENT, + "PJRT library " + std::string(library_path) + + " has incompatible API version: " + std::to_string(major) + "." + + std::to_string(minor) + " vs " + std::to_string(PJRT_API_MAJOR) + + "." + std::to_string(PJRT_API_MINOR)); + } + Initialize(); +} + +std::vector Pjrt::GetAttributes() const { + auto args = MakeStruct(); + CheckError(api_->PJRT_Plugin_Attributes(&args)); + std::vector result; + result.reserve(args.num_attributes); + for (size_t i = 0; i < args.num_attributes; ++i) { + result.push_back(MakeKeyValue(args.attributes + i)); + } + return result; +} + +std::unique_ptr Pjrt::CreateClient() { + auto args = MakeStruct(); + CheckError(api_->PJRT_Client_Create(&args)); + return std::make_unique(api_, args.client); +} + +std::pair Pjrt::ApiVersion() const { + return std::make_pair(api_->pjrt_api_version.major_version, + api_->pjrt_api_version.minor_version); +} + +void Pjrt::Initialize() { + auto args = MakeStruct(); + CheckError(api_->PJRT_Plugin_Initialize(&args)); +} + +} // namespace lczero \ No newline at end of file diff --git a/src/neural/xla/pjrt.h b/src/neural/xla/pjrt.h new file mode 100644 index 0000000000..1fd250e34c --- /dev/null +++ b/src/neural/xla/pjrt.h @@ -0,0 +1,252 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. +*/ + +// This file contains set of C++ wrappers around the PJRT C API. + +#pragma once + +#include +#include +#include +#include + +struct PJRT_Api; +struct PJRT_Buffer; +struct PJRT_Client; +struct PJRT_Device; +struct PJRT_DeviceDescription; +struct PJRT_Error; +struct PJRT_Event; +struct PJRT_LoadedExecutable; + +namespace lczero { + +// PJRT_Error_Code as enum class. Coincidentally, the error codes are the same +// as in absl Status module. +enum class PjrtErrorCode { + CANCELLED = 1, + UNKNOWN = 2, + INVALID_ARGUMENT = 3, + DEADLINE_EXCEEDED = 4, + NOT_FOUND = 5, + ALREADY_EXISTS = 6, + PERMISSION_DENIED = 7, + RESOURCE_EXHAUSTED = 8, + FAILED_PRECONDITION = 9, + ABORTED = 10, + OUT_OF_RANGE = 11, + UNIMPLEMENTED = 12, + INTERNAL = 13, + UNAVAILABLE = 14, + DATA_LOSS = 15, + UNAUTHENTICATED = 16 +}; + +// PJRT_Type as enum class. Conincidentally, the types are the same as in XLA +// and HloModuleProto, so simple cast works. +enum class PjrtType { + INVALID, + PRED, + S8, + S16, + S32, + S64, + U8, + U16, + U32, + U64, + F16, + F32, + F64, + BF16, + C64, + C128, + F8E5M2, + F8E4M3FN, + F8E4M3B11FNUZ, + F8E5M2FNUZ, + F8E4M3FNUZ, + S4, + U4, +}; + +// PJRT errors as exceptions. +class PjrtException : public std::exception { + public: + explicit PjrtException(PjrtErrorCode code, const std::string& message) + : message_(message), code_(code) {} + + const char* what() const noexcept override { return message_.data(); } + PjrtErrorCode code() const { return code_; } + + private: + std::string message_; + PjrtErrorCode code_; +}; + +// PJRT_NamedValue wrapper. PJRT_NamedValue is a string-keyed values that are +// used for auxiliary functionality like plugin attributes. +class PjrtKeyValue { + public: + PjrtKeyValue() = default; + PjrtKeyValue(const PjrtKeyValue&) = default; + PjrtKeyValue(PjrtKeyValue&&) = default; + template + PjrtKeyValue(const std::string& k, const T& v) : key_(k), value_(v) {} + + const std::string& key() const { return key_; } + // Converts the value to string. This is useful for logging and debugging. + std::string value_as_string() const; + + void set_key(const std::string& key) { key_ = key; } + void set_value(const std::string& value) { value_ = value; } + void set_value(int64_t value) { value_ = value; } + void set_value(const std::vector& value) { value_ = value; } + void set_value(float value) { value_ = value; } + void set_value(bool value) { value_ = value; } + + private: + std::string key_; + std::variant, float, bool> value_; +}; + +// A shared base class for all wrappers. Keeps the API pointer and auxiliary +// functions like error checking. +class PjrtCommon { + protected: + PjrtCommon(const PJRT_Api* api) : api_(api) {} + virtual ~PjrtCommon() = default; + + std::string GetErrorMessage(PJRT_Error* error) const; + void DestroyErrorMessage(PJRT_Error* error) const; + void CheckError(PJRT_Error* error) const; + + const PJRT_Api* api_; +}; + +class PjrtDevice : protected PjrtCommon { + public: + PjrtDevice(const PJRT_Api* api, PJRT_Device* device); + std::string ToString() const; + + private: + PJRT_Device* device_; + PJRT_DeviceDescription* description_; + friend class PjrtExecutable; + friend class PjrtClient; +}; + +// An event for waiting for asynchronous operations. +class PjrtEvent : protected PjrtCommon { + public: + PjrtEvent(const PJRT_Api* api, PJRT_Event* event); + // Blocks until the operation is complete. + void Await(); + ~PjrtEvent(); + + private: + PJRT_Event* event_; +}; + +// A buffer in the device memory. +class PjrtDeviceBuffer : protected PjrtCommon { + public: + PjrtDeviceBuffer(const PJRT_Api* api, PJRT_Buffer* buffer); + ~PjrtDeviceBuffer(); + // Returns the size of the buffer in bytes. + size_t GetSize() const; + // Starts an asynchronous copy from the device to the host memory. + [[nodiscard]] std::unique_ptr DeviceToHost(void* dst, size_t size); + PjrtType GetType() const; + std::vector GetDimensions() const; + + private: + PJRT_Buffer* buffer_; + friend class PjrtExecutable; +}; + +class PjrtExecutable : protected PjrtCommon { + public: + PjrtExecutable(const PJRT_Api* api, PJRT_LoadedExecutable* executable); + ~PjrtExecutable(); + // Executes the executable with the given inputs. The inputs are not owned or + // modified. The function allocates the output buffers and returns them. + std::vector> ExecuteBlocking( + const std::vector& inputs); + size_t GetNumOutputs() const; + + private: + PJRT_LoadedExecutable* executable_; + size_t num_outputs_; +}; + +// Ongoing host-to-device transfer. After the transfer is complete, it's +// possible to fetch the device buffer. +class PjrtHostToDeviceTransfer : protected PjrtCommon { + public: + PjrtHostToDeviceTransfer(const PJRT_Api* api, PJRT_Buffer* buffer, + std::unique_ptr event); + ~PjrtHostToDeviceTransfer(); + // Blocks until the transfer is complete. (not really necessary as + // AwaitAndReleaseBuffer() waits anyway) + void Await(); + // Waits for the transfer to complete and releases the ownership of the + // buffer. + std::unique_ptr AwaitAndReleaseBuffer(); + + private: + PJRT_Buffer* buffer_; + std::unique_ptr event_; +}; + +class PjrtClient : protected PjrtCommon { + public: + PjrtClient(const PJRT_Api* api, PJRT_Client* client); + ~PjrtClient(); + std::unique_ptr CompileHlo(std::string_view hlo, + std::string_view config); + std::vector> GetDevices(); + std::unique_ptr HostToDevice( + std::string_view buffer, PjrtType type, const std::vector& dims, + const PjrtDevice* device); + + private: + PJRT_Client* client_; +}; + +class Pjrt : protected PjrtCommon { + public: + Pjrt(const char* library_path); + std::vector GetAttributes() const; + std::unique_ptr CreateClient(); + std::pair ApiVersion() const; + + private: + void Initialize(); +}; + +} // namespace lczero diff --git a/src/neural/xla/print_hlo.cc b/src/neural/xla/print_hlo.cc new file mode 100644 index 0000000000..fa7d406bc3 --- /dev/null +++ b/src/neural/xla/print_hlo.cc @@ -0,0 +1,390 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. +*/ + +#include "neural/xla/print_hlo.h" + +namespace lczero { +namespace { + +// C-escapes the given string, and appends double quotes around it. +std::string CEscape(std::string_view str) { + std::string result = "\""; + for (char c : str) { + switch (c) { + case '\n': + result += "\\n"; + break; + case '\t': + result += "\\t"; + break; + case '\r': + result += "\\r"; + break; + case '\v': + result += "\\v"; + break; + case '\f': + result += "\\f"; + break; + case '\\': + result += "\\\\"; + break; + case '\"': + result += "\\\""; + break; + case '\'': + result += "\\\'"; + break; + default: + result += c; + } + } + return result + "\""; +} + +class HloPrettyPrinter { + public: + HloPrettyPrinter(PrettyPrintHloOptions options, std::ostream& stream) + : options_(options), s_(stream) {} + + void PrintModule(const pblczero::HloModuleProto& module) { + current_module_ = &module; + s_ << "HloModule " << module.name(); + if (module.has_host_program_shape()) { + s_ << ", entry_computation_layout="; + PrintProgramShape(module.host_program_shape()); + } + s_ << "\n"; + + for (const auto& computation : module.computations()) { + s_ << "\n"; + if (module.entry_computation_id() == computation.id()) s_ << "ENTRY "; + PrintComputation(computation); + } + current_module_ = nullptr; + } + + private: + // Prints delimeted list, with value rendering function and with optional + // prefix and suffix. E.g. for vec=[1,2,3] and f=(x) -> print(x, x), delim="," + // and prefix="{" and suffix="}" it will print "{11,22,33}". + template + void PrintDelimeted(const T& vec, F print_fn, std::string_view delim, + std::string_view prefix = "", + std::string_view suffix = "") { + s_ << prefix; + for (size_t i = 0; i < vec.size(); ++i) { + if (i > 0) s_ << delim; + print_fn(vec[i]); + } + s_ << suffix; + } + + // Returns the name of the type, which is the lowercase enum value name. + std::string GetTypeLiteral(pblczero::XlaShapeProto::Type type) { + std::string name = pblczero::XlaShapeProto::Type_Name(type); + for (char& c : name) c = std::tolower(c); + return name; + } + + // Prints the tensor layout (e.g. {3,2,1,0} for major-to-minor layout). + void PrintLayout(const pblczero::XlaLayoutProto& layout) { + if (!options_.print_layout) return; + PrintDelimeted( + layout.minor_to_major(), [&](const auto& dim) { s_ << dim; }, ",", " {", + "}"); + } + + // Prints the shape of a tensor, including the type (e.g. f32[112,8,8]). + void PrintShape(const pblczero::XlaShapeProto& shape) { + if (shape.element_type() == pblczero::XlaShapeProto::TUPLE) { + PrintDelimeted( + shape.tuple_shapes(), [&](const auto& s) { PrintShape(s); }, ", ", + "(", ")"); + return; + } + s_ << GetTypeLiteral(shape.element_type()); + PrintDelimeted( + shape.dimensions(), [&](int64_t dim) { s_ << dim; }, ",", "[", "]"); + if (shape.has_layout()) PrintLayout(shape.layout()); + } + + // Prints the program shape (i.e. shapes of parameters and output). + void PrintProgramShape(const pblczero::XlaProgramShapeProto& shape) { + s_ << "{("; + for (size_t i = 0; i < shape.parameters_size(); ++i) { + if (i > 0) s_ << ", "; + if (shape.parameter_names_size() > i && + !shape.parameter_names(i).empty()) { + s_ << shape.parameter_names(i) << ": "; + } + PrintShape(shape.parameters(i)); + } + s_ << ") -> "; + PrintShape(shape.result()); + s_ << "}"; + } + + // Prints the literal (i.e. constant value). + void PrintLiteral(const pblczero::XlaLiteralProto& literal) { + // For now just print as a flat array with sometimes wrong encoding (i.e. in + // bf16 case). + auto print_array = [&](const auto& array) { + PrintDelimeted( + array, + [&](const auto& x) { + if constexpr (std::is_same_v, char> || + std::is_same_v, bool>) { + s_ << static_cast(x); + } else { + s_ << x; + } + }, + ","); + }; + switch (literal.shape().element_type()) { + case pblczero::XlaShapeProto::TUPLE: + PrintDelimeted( + literal.tuple_literals(), [&](const auto& l) { PrintLiteral(l); }, + ", ", "(", ")"); + break; + case pblczero::XlaShapeProto::TOKEN: + s_ << "token"; + break; + case pblczero::XlaShapeProto::PRED: + return print_array(literal.preds()); + case pblczero::XlaShapeProto::S4: + return print_array(literal.s4s()); + case pblczero::XlaShapeProto::U4: + return print_array(literal.u4s()); + case pblczero::XlaShapeProto::S8: + return print_array(literal.s8s()); + case pblczero::XlaShapeProto::U8: + return print_array(literal.u8s()); + case pblczero::XlaShapeProto::S32: + return print_array(literal.s32s()); + case pblczero::XlaShapeProto::S64: + return print_array(literal.s64s()); + case pblczero::XlaShapeProto::U32: + return print_array(literal.u32s()); + case pblczero::XlaShapeProto::U64: + return print_array(literal.u64s()); + case pblczero::XlaShapeProto::F32: + return print_array(literal.f32s()); + case pblczero::XlaShapeProto::F64: + return print_array(literal.f64s()); + case pblczero::XlaShapeProto::C64: + return print_array(literal.c64s()); + case pblczero::XlaShapeProto::C128: + return print_array(literal.c128s()); + case pblczero::XlaShapeProto::F16: + return print_array(literal.f16s()); + case pblczero::XlaShapeProto::BF16: + return print_array(literal.bf16s()); + case pblczero::XlaShapeProto::U16: + return print_array(literal.u16s()); + case pblczero::XlaShapeProto::S16: + return print_array(literal.s16s()); + case pblczero::XlaShapeProto::F8E5M2: + return print_array(literal.f8e5m2s()); + case pblczero::XlaShapeProto::F8E4M3FN: + return print_array(literal.f8e4m3fns()); + case pblczero::XlaShapeProto::F8E4M3B11FNUZ: + return print_array(literal.f8e4m3b11fnuzs()); + case pblczero::XlaShapeProto::F8E5M2FNUZ: + return print_array(literal.f8e5m2fnuzs()); + case pblczero::XlaShapeProto::F8E4M3FNUZ: + return print_array(literal.f8e4m3fnuzs()); + case pblczero::XlaShapeProto::PRIMITIVE_TYPE_INVALID: + s_ << "INVALID"; + break; + case pblczero::XlaShapeProto::OPAQUE_TYPE: + s_ << "opaque"; + break; + } + } + + // Prints the operands of the given instruction. Usually operands are stored + // in operands() fields, but some opcodes have operands in the other fields. + void PrintInstructionOperands( + const pblczero::HloInstructionProto& instruction) { + s_ << "("; + if (instruction.opcode() == "parameter") { + s_ << instruction.parameter_number(); + } else if (instruction.opcode() == "get-tuple-index") { + s_ << instruction.tuple_index(); + } else if (instruction.opcode() == "constant") { + PrintLiteral(instruction.literal()); + } else { + PrintDelimeted( + instruction.operand_ids(), + [&](int64_t id) { + s_ << "%" << current_computation_->instructions(id).name(); + }, + ", "); + } + s_ << ")"; + } + + // Prints the "window" attribute (for convolution opcodes). + void PrintWindow(const pblczero::XlaWindow& window) { + PrintDelimeted( + window.dimensions(), [&](const auto& d) { s_ << d.size(); }, "x", + "size="); + PrintDelimeted( + window.dimensions(), + [&](const auto& d) { + s_ << d.padding_low() << "_" << d.padding_high(); + }, + "x", " pads="); + } + + // Prints the "dimension numbers" attribute (for dot opcodes). + void PrintDotDimensionNumbers(const pblczero::XlaDotDimensionNumbers& dn) { + PrintDelimeted( + dn.lhs_batch_dimensions(), [&](int64_t dim) { s_ << dim; }, ",", + ", lhs_batch_dims={", "}"); + PrintDelimeted( + dn.rhs_batch_dimensions(), [&](int64_t dim) { s_ << dim; }, ",", + ", rhs_batch_dims={", "}"); + PrintDelimeted( + dn.lhs_contracting_dimensions(), [&](int64_t dim) { s_ << dim; }, ",", + ", lhs_contracting_dims={", "}"); + PrintDelimeted( + dn.rhs_contracting_dimensions(), [&](int64_t dim) { s_ << dim; }, ",", + ", rhs_contracting_dims={", "}"); + } + + // Prints the "dimension numbers" attribute (for convolution opcodes). + void PrintConvolutionDimensionNumbers( + const pblczero::XlaConvolutionDimensionNumbers& dn) { + std::string input_dims(dn.input_spatial_dimensions_size() + 2, '?'); + std::string kernel_dims(dn.kernel_spatial_dimensions_size() + 2, '?'); + std::string output_dims(dn.output_spatial_dimensions_size() + 2, '?'); + input_dims[dn.input_batch_dimension()] = 'b'; + input_dims[dn.input_feature_dimension()] = 'f'; + kernel_dims[dn.kernel_output_feature_dimension()] = 'o'; + kernel_dims[dn.kernel_input_feature_dimension()] = 'i'; + output_dims[dn.output_batch_dimension()] = 'b'; + output_dims[dn.output_feature_dimension()] = 'f'; + for (size_t i = 0; i < dn.input_spatial_dimensions_size(); ++i) { + input_dims[dn.input_spatial_dimensions(i)] = '0' + i; + kernel_dims[dn.kernel_spatial_dimensions(i)] = '0' + i; + output_dims[dn.output_spatial_dimensions(i)] = '0' + i; + } + s_ << input_dims << "_" << kernel_dims << "->" << output_dims; + } + + // Prints the attributes of the given instruction. + void PrintInstructionAttributes( + const pblczero::HloInstructionProto& instruction) { + if (instruction.called_computation_ids_size() > 0) { + PrintDelimeted( + instruction.called_computation_ids(), + [&](int64_t id) { s_ << current_module_->computations(id).name(); }, + ",", ", calls={", "}"); + } + if (instruction.has_window()) { + s_ << ", window={"; + PrintWindow(instruction.window()); + s_ << "}"; + } + if (instruction.has_convolution_dimension_numbers()) { + s_ << ", dim_labels="; + PrintConvolutionDimensionNumbers( + instruction.convolution_dimension_numbers()); + } + if (instruction.dimensions_size() > 0) { + PrintDelimeted( + instruction.dimensions(), [&](int64_t dim) { s_ << dim; }, ", ", + ", dimensions={", "}"); + } + if (instruction.has_dot_dimension_numbers()) { + PrintDotDimensionNumbers(instruction.dot_dimension_numbers()); + } + } + + // Prints the metadata of the given instruction (source file, line, etc). + void PrintInstructionMetadata( + const pblczero::HloInstructionProto& instruction) { + if (instruction.has_metadata()) { + const auto& m = instruction.metadata(); + s_ << ", metadata={"; + bool first = true; + auto sep = [&]() -> std::ostream& { + if (!first) s_ << ", "; + first = false; + return s_; + }; + std::vector bits; + if (m.has_op_type()) sep() << "op_type=" << CEscape(m.op_type()); + if (m.has_op_name()) sep() << "op_name=" << CEscape(m.op_name()); + if (m.has_source_file()) + sep() << "source_file=" << CEscape(m.source_file()); + if (m.has_source_line()) sep() << "source_line=" << m.source_line(); + s_ << "}"; + } + } + + // Prints the given instruction line. + void PrintInstruction(const pblczero::HloInstructionProto& instruction) { + s_ << "%" << instruction.name() << " = "; + PrintShape(instruction.shape()); + s_ << " " << instruction.opcode(); + PrintInstructionOperands(instruction); + PrintInstructionAttributes(instruction); + PrintInstructionMetadata(instruction); + } + + // Prints the given computation. + void PrintComputation(const pblczero::HloComputationProto& computation) { + current_computation_ = &computation; + s_ << computation.name() << " {\n"; + for (const auto& instruction : computation.instructions()) { + s_ << " "; + if (computation.root_id() == instruction.id()) s_ << "ROOT "; + PrintInstruction(instruction); + s_ << "\n"; + } + s_ << "}\n"; + current_computation_ = nullptr; + } + + PrettyPrintHloOptions options_; + const pblczero::HloModuleProto* current_module_ = nullptr; + const pblczero::HloComputationProto* current_computation_ = nullptr; + std::ostream& s_; +}; + +} // namespace + +void PrettyPrintHlo(const pblczero::HloModuleProto& module, + PrettyPrintHloOptions options, std::ostream& stream) { + HloPrettyPrinter(options, stream).PrintModule(module); +} + +} // namespace lczero \ No newline at end of file diff --git a/src/neural/xla/print_hlo.h b/src/neural/xla/print_hlo.h new file mode 100644 index 0000000000..28db989ef4 --- /dev/null +++ b/src/neural/xla/print_hlo.h @@ -0,0 +1,44 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. +*/ + +#include + +#include "neural/xla/hlo.pb.h" + +namespace lczero { + +struct PrettyPrintHloOptions { + // Print layout information (which is always major-to-minor now, e.g. + // {3,2,1,0}. Switched off by default as it's just noise. + bool print_layout = false; +}; + +// Pretty-prints the given HLO module to the given stream. +void PrettyPrintHlo(const pblczero::HloModuleProto& module, + PrettyPrintHloOptions options, std::ostream& stream); + +} // namespace lczero \ No newline at end of file diff --git a/src/neural/xla/xla_runner.cc b/src/neural/xla/xla_runner.cc new file mode 100644 index 0000000000..922d38f7fc --- /dev/null +++ b/src/neural/xla/xla_runner.cc @@ -0,0 +1,214 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. +*/ + +#include "neural/xla/xla_runner.h" + +#include + +#include "utils/exception.h" +#include "utils/logging.h" + +namespace lczero { +namespace { + +size_t GetTypeSize(pblczero::XlaShapeProto::Type type) { + switch (type) { + case pblczero::XlaShapeProto::F32: + return sizeof(float); + case pblczero::XlaShapeProto::F64: + return sizeof(double); + case pblczero::XlaShapeProto::S32: + return sizeof(int32_t); + case pblczero::XlaShapeProto::S64: + return sizeof(int64_t); + default: + throw Exception("Add size for type " + + pblczero::XlaShapeProto::Type_Name(type)); + } +} + +std::string AsHexString(std::string_view buf) { + std::string result; + result.reserve(buf.size() * 2); + constexpr char hex[] = "0123456789abcdef"; + for (unsigned char c : buf) { + result.push_back(hex[c >> 4]); + result.push_back(hex[c & 0xf]); + } + return result; +} + +} // namespace + +std::string XlaTensor::DebugString() { + constexpr size_t kMaxSize = 1000; + constexpr size_t kSuffixSize = 200; + std::string result = "XlaTensor("; + result += "shape=["; + for (size_t i = 0; i < shape().size(); ++i) { + if (i > 0) result += ", "; + result += std::to_string(shape()[i]); + } + result += "], type="; + result += pblczero::XlaShapeProto::Type_Name(type()); + result += ") size=" + std::to_string(size()); + result += " data="; + if (size() <= kMaxSize) { + result += AsHexString({static_cast(data()), size()}); + } else { + result += AsHexString( + {static_cast(data()), kMaxSize - kSuffixSize - 2}); + result += "...."; + result += AsHexString( + {static_cast(data()) + size() - kSuffixSize, kSuffixSize}); + } + return result; +} + +XlaRunner::XlaRunner(const char* library_path, int device) + : pjrt_client_(Pjrt(library_path).CreateClient()), device_(device) { + CERR << "Devices:"; + devices_ = pjrt_client_->GetDevices(); + for (const auto& device : devices_) { + CERR << " " << device->ToString(); + } + if (devices_.empty()) { + throw Exception("No devices available"); + } +} + +void XlaRunner::AddModule(size_t minibatch_size, + const pblczero::HloModuleProto& module) { + pblczero::CompileOptionsProto options; + options.mutable_executable_build_options()->set_num_replicas(1); + options.mutable_executable_build_options()->set_num_partitions(1); + options.mutable_executable_build_options()->set_device_ordinal(device_); + auto executable = pjrt_client_->CompileHlo(module.OutputAsString(), + options.OutputAsString()); + executables_.push_back({minibatch_size, std::move(executable)}); + std::sort(executables_.begin(), executables_.end()); +} + +void XlaRunner::SetFrozenInputs( + const std::vector> inputs) { + param_idxs_.clear(); + std::vector> transfers_; + for (size_t i = 0; i < inputs.size(); ++i) { + const auto* input = inputs[i].get(); + if (!input) { + param_idxs_.push_back(i); + continue; + } + transfers_.push_back(pjrt_client_->HostToDevice( + {static_cast(input->data()), input->size()}, + static_cast(input->type()), input->shape(), + devices_.at(device_).get())); + } + + owned_buffers_.clear(); + buffers_.clear(); + buffers_.resize(inputs.size()); + size_t transfer_idx = 0; + for (size_t i = 0; i < inputs.size(); ++i) { + if (inputs[i]) { + owned_buffers_.push_back( + transfers_[transfer_idx++]->AwaitAndReleaseBuffer()); + buffers_[i] = owned_buffers_.back().get(); + } + } +} + +size_t XlaRunner::GetMaxBatchSize() const { return executables_.back().first; } + +std::vector> XlaRunner::ExecuteBlocking( + const std::vector& inputs) { + if (inputs.size() != 1) { + throw Exception("Only one input is kinda supported."); + } + // Find the smallest batch size that fits the input. + auto iter = std::find_if( + executables_.begin(), executables_.end(), [&](const auto& e) { + return e.first >= static_cast(inputs[0]->shape()[0]); + }); + if (iter == executables_.end()) { + throw Exception("No executable found for batch size " + + std::to_string(inputs[0]->shape()[0])); + } + const size_t batch_size = iter->first; + // Update the shape to match the rounded up batch size. After growing, the + // batch size must fit within tensor buffer capacity (it's fine to have + // garbage in the tail of that buffer). + std::vector new_shape = inputs[0]->shape(); + new_shape[0] = batch_size; + const size_t input_size = std::accumulate(new_shape.begin(), new_shape.end(), + 1, std::multiplies()) * + GetTypeSize(inputs[0]->type()); + if (input_size > inputs[0]->capacity()) { + throw Exception("Input buffer too small"); + } + // Transfer the input to the device. + auto input_buffer = + pjrt_client_ + ->HostToDevice( + {static_cast(inputs[0]->data()), input_size}, + static_cast(inputs[0]->type()), new_shape, + devices_[0].get()) + ->AwaitAndReleaseBuffer(); + // Make a copy to support multiple concurrent calls, not sure if it's needed. + auto input_buffers = buffers_; + input_buffers[param_idxs_[0]] = input_buffer.get(); + // Execute! + auto outputs = iter->second->ExecuteBlocking(input_buffers); + + // Now we need to transfer the outputs back to the host. + std::vector> result; + result.reserve(outputs.size()); + std::vector output_buffers; + std::vector> done_events; + output_buffers.reserve(outputs.size()); + done_events.reserve(outputs.size()); + // Initialte transfers from device to host. + for (size_t i = 0; i < outputs.size(); ++i) { + const auto& output = outputs[i]; + output_buffers.emplace_back(); + auto& buffer = output_buffers.back(); + buffer.resize(output->GetSize()); + done_events.push_back(output->DeviceToHost(&buffer[0], buffer.size())); + } + // Wait for the transfers to complete. + for (size_t i = 0; i < outputs.size(); ++i) { + const auto& output = outputs[i]; + done_events[i]->Await(); + result.push_back(std::make_unique( + output->GetDimensions(), + static_cast(output->GetType()), + std::move(output_buffers[i]))); + } + return result; +} + +} // namespace lczero \ No newline at end of file diff --git a/src/neural/xla/xla_runner.h b/src/neural/xla/xla_runner.h new file mode 100644 index 0000000000..85352c4142 --- /dev/null +++ b/src/neural/xla/xla_runner.h @@ -0,0 +1,134 @@ +/* + This file is part of Leela Chess Zero. + Copyright (C) 2024 The LCZero Authors + + Leela Chess is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Leela Chess is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with Leela Chess. If not, see . + + Additional permission under GNU GPL version 3 section 7 + + If you modify this Program, or any covered work, by linking or + combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA + Toolkit and the NVIDIA CUDA Deep Neural Network library (or a + modified version of those libraries), containing parts covered by the + terms of the respective license agreement, the licensors of this + Program grant you additional permission to convey the resulting work. +*/ + +#pragma once + +#include +#include +#include +#include +#include + +#include "neural/xla/hlo.pb.h" +#include "neural/xla/pjrt.h" + +namespace lczero { + +// An interface for in-host-memory tensor in XLA format. +class XlaTensor { + public: + virtual ~XlaTensor() = default; + virtual const std::vector& shape() const = 0; + virtual const void* data() const = 0; + // Returns amount of valid data in bytes. + virtual size_t size() const = 0; + // Returns amount of memory that are allowed to address in bytes. This is + // useful when the size of the buffer has to be increased to match the + // supported batch size. + virtual size_t capacity() const = 0; + virtual pblczero::XlaShapeProto::Type type() const = 0; + + std::string DebugString(); +}; + +// Not-owned XLA tensor, used e.g. when ONNX buffer can be used directly, to +// avoid unnecessary copy. +class XlaTensorNotOwned : public XlaTensor { + public: + XlaTensorNotOwned(const std::vector& shape, std::string_view data, + pblczero::XlaShapeProto::Type type) + : shape_(&shape), data_(data), type_(type) {} + + const std::vector& shape() const override { return *shape_; } + const void* data() const override { return data_.data(); } + size_t size() const override { return data_.size(); } + size_t capacity() const override { return data_.size(); } + pblczero::XlaShapeProto::Type type() const override { return type_; } + + private: + const std::vector* shape_; + std::string_view data_; + pblczero::XlaShapeProto::Type type_; +}; + +// Tensor that owns data, used e.g. for XLA output. +class XlaTensorOwned : public XlaTensor { + public: + XlaTensorOwned(const std::vector& shape, + pblczero::XlaShapeProto::Type type, std::string data) + : shape_(shape), type_(type), data_(data) {} + + const std::vector& shape() const override { return shape_; } + const void* data() const override { return data_.data(); } + size_t size() const override { return data_.size(); } + size_t capacity() const override { return data_.size(); } + pblczero::XlaShapeProto::Type type() const override { return type_; } + + private: + std::vector shape_; + pblczero::XlaShapeProto::Type type_; + std::string data_; +}; + +// A class that keeps several XLA executables (for different batch sizes), +// manages common buffers among them, and chooses the right executable for a +// batch size. +class XlaRunner { + public: + // The library_path is the path to the PJRT library, and device indx. + XlaRunner(const char* library_path, int device); + // Compiles and adds a module for the given batch size. + void AddModule(size_t minibatch_size, const pblczero::HloModuleProto& module); + // Transfers inputs to the device and execute the executable corresponding to + // the batch size. Only non-frozen inputs are passed as arguments. + // Currnetly only single input is supported (just because we don't need more). + std::vector> ExecuteBlocking( + const std::vector& inputs); + // Inputs that are shared between all calls (i.e. network weights passed as + // parameters). These inputs are transferred to device immediately (and not + // for each inference). + void SetFrozenInputs(const std::vector> inputs); + // Maximum supported batch size. It's expected that the capacity (not size) of + // the input tensors would be able to fit this size. + size_t GetMaxBatchSize() const; + + private: + std::unique_ptr pjrt_client_; + std::vector> devices_; + // Compiled executables per batch size. + std::vector>> executables_; + // Frozen inputs, in no particular order, kept for ownership. + std::vector> owned_buffers_; + // Vector of pointers to all input buffers, that is passed to PJRT. Frozen + // parameters (constants) are pre-filled in SetFrozenInputs(), and non-frozen + // inputs (input planes) are created and filled in every request. + std::vector buffers_; + std::vector param_idxs_; + int device_; +}; + +} // namespace lczero \ No newline at end of file diff --git a/third_party/pjrt_c_api.h b/third_party/pjrt_c_api.h new file mode 100644 index 0000000000..f1ab16f1a6 --- /dev/null +++ b/third_party/pjrt_c_api.h @@ -0,0 +1,2175 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PJRT_C_PJRT_C_API_H_ +#define XLA_PJRT_C_PJRT_C_API_H_ + +#include +#include +#include + +#define PJRT_STRUCT_SIZE(struct_type, last_field) \ + offsetof(struct_type, last_field) + sizeof(((struct_type*)0)->last_field) + +#define PJRT_DEFINE_STRUCT_TRAITS(sname, last_field) \ + typedef struct sname sname; \ + enum { sname##_STRUCT_SIZE = PJRT_STRUCT_SIZE(sname, last_field) } + +#ifdef __cplusplus +extern "C" { +#endif + +// --------------------------------- Version ----------------------------------- + +// Incremented when an ABI-incompatible change is made to the interface. +// Changes include: +// * Deleting a method or argument +// * Changing the type of an argument +// * Rearranging fields in the PJRT_Api or argument structs +#define PJRT_API_MAJOR 0 + +// Incremented when the interface is updated in a way that is potentially +// ABI-compatible with older versions, if supported by the caller and/or +// implementation. +// +// Callers can implement forwards compatibility by using PJRT_Api_Version to +// check if the implementation is aware of newer interface additions. +// +// Implementations can implement backwards compatibility by using the +// `struct_size` fields to detect how many struct fields the caller is aware of. +// +// Changes include: +// * Adding a new field to the PJRT_Api or argument structs +// * Renaming a method or argument (doesn't affect ABI) +#define PJRT_API_MINOR 40 + +// The plugin should set the major_version and minor_version of +// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in +// this header that the implementation was compiled with. +struct PJRT_Api_Version { + size_t struct_size; + void* priv; + int major_version; // out + int minor_version; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Api_Version, minor_version); + +// ---------------------------------- Errors ----------------------------------- + +// PJRT C API methods generally return a PJRT_Error*, which is nullptr if there +// is no error and set if there is. The implementation allocates any returned +// PJRT_Errors, but the caller is always responsible for freeing them via +// PJRT_Error_Destroy. + +typedef struct PJRT_Error PJRT_Error; + +struct PJRT_Error_Destroy_Args { + size_t struct_size; + void* priv; + PJRT_Error* error; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Error_Destroy_Args, error); + +// Frees `error`. `error` can be nullptr. +typedef void PJRT_Error_Destroy(PJRT_Error_Destroy_Args* args); + +struct PJRT_Error_Message_Args { + size_t struct_size; + void* priv; + const PJRT_Error* error; + // Has the lifetime of `error`. + const char* message; // out + size_t message_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Error_Message_Args, message_size); + +// Gets the human-readable reason for `error`. `message` has the lifetime of +// `error`. +typedef void PJRT_Error_Message(PJRT_Error_Message_Args* args); + +// Codes are based on https://abseil.io/docs/cpp/guides/status-codes +typedef enum { + PJRT_Error_Code_CANCELLED = 1, + PJRT_Error_Code_UNKNOWN = 2, + PJRT_Error_Code_INVALID_ARGUMENT = 3, + PJRT_Error_Code_DEADLINE_EXCEEDED = 4, + PJRT_Error_Code_NOT_FOUND = 5, + PJRT_Error_Code_ALREADY_EXISTS = 6, + PJRT_Error_Code_PERMISSION_DENIED = 7, + PJRT_Error_Code_RESOURCE_EXHAUSTED = 8, + PJRT_Error_Code_FAILED_PRECONDITION = 9, + PJRT_Error_Code_ABORTED = 10, + PJRT_Error_Code_OUT_OF_RANGE = 11, + PJRT_Error_Code_UNIMPLEMENTED = 12, + PJRT_Error_Code_INTERNAL = 13, + PJRT_Error_Code_UNAVAILABLE = 14, + PJRT_Error_Code_DATA_LOSS = 15, + PJRT_Error_Code_UNAUTHENTICATED = 16 +} PJRT_Error_Code; + +struct PJRT_Error_GetCode_Args { + size_t struct_size; + void* priv; + const PJRT_Error* error; + PJRT_Error_Code code; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Error_GetCode_Args, code); + +typedef PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args); + +// Function for PJRT implementation to pass to callback functions provided by +// caller so the callback can create a PJRT_Error* on error (to return to the +// implementation). `message` is only required to live for the +// PJRT_CallbackError call, i.e. the PJRT_CallbackError implementation must copy +// `message` into the PJRT_Error. +typedef PJRT_Error* (*PJRT_CallbackError)(PJRT_Error_Code code, + const char* message, + size_t message_size); + +// ---------------------------- Named Values ----------------------------------- + +typedef enum { + PJRT_NamedValue_kString = 0, + PJRT_NamedValue_kInt64, + PJRT_NamedValue_kInt64List, + PJRT_NamedValue_kFloat, + PJRT_NamedValue_kBool, +} PJRT_NamedValue_Type; + +// Named value for key-value pairs. +struct PJRT_NamedValue { + size_t struct_size; + void* priv; + const char* name; + size_t name_size; + PJRT_NamedValue_Type type; + union { + const char* string_value; + int64_t int64_value; + const int64_t* int64_array_value; + float float_value; + bool bool_value; + }; + // `value_size` is the number of elements for array/string and 1 for scalar + // values. + size_t value_size; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_NamedValue, value_size); + +// ---------------------------------- Plugin ----------------------------------- + +struct PJRT_Plugin_Initialize_Args { + size_t struct_size; + void* priv; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Initialize_Args, priv); + +// One-time plugin setup. Must be called before any other functions are called. +typedef PJRT_Error* PJRT_Plugin_Initialize(PJRT_Plugin_Initialize_Args* args); + +struct PJRT_Plugin_Attributes_Args { + size_t struct_size; + void* priv; + // Returned attributes have the lifetime of the process. + const PJRT_NamedValue* attributes; // out + size_t num_attributes; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Attributes_Args, attributes); + +// Returns an array of plugin attributes which are key-value pairs. One example +// attribute is the minimum supported StableHLO version. +// TODO(b/280349977): standardize the list of attributes. +typedef PJRT_Error* PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args* args); + +// ---------------------------------- Events ----------------------------------- + +// Represents a notifying event that is returned by PJRT APIs that enqueue +// asynchronous work, informing callers when the work is complete and reporting +// a value of type `PJRT_Error*` or `nullptr` as error status. +// +// Callers are always responsible for freeing `PJRT_Event`s by calling +// `PJRT_Event_Destroy`. +typedef struct PJRT_Event PJRT_Event; + +struct PJRT_Event_Destroy_Args { + size_t struct_size; + void* priv; + PJRT_Event* event; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_Destroy_Args, event); + +// Frees `event`. `event` can be `nullptr`. +typedef PJRT_Error* PJRT_Event_Destroy(PJRT_Event_Destroy_Args* args); + +struct PJRT_Event_IsReady_Args { + size_t struct_size; + void* priv; + PJRT_Event* event; + bool is_ready; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_IsReady_Args, is_ready); + +// Returns true if this PJRT_Event has completed, including if an error has +// occurred. +typedef PJRT_Error* PJRT_Event_IsReady(PJRT_Event_IsReady_Args* args); + +struct PJRT_Event_Error_Args { + size_t struct_size; + void* priv; + PJRT_Event* event; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_Error_Args, event); + +// Should only be called if PJRT_Event_IsReady returns true. +// Returns `nullptr` if there is no error. +// The returned error should be freed with `PJRT_Error_Destroy`. +// +// If `PJRT_Event_Await` has been called, this will return a pointer to an +// identical error status as that call, as will subsequent calls to +// `PJRT_Event_Error`. However, each of these `PJRT_Error *` pointers are +// independent of `PJRT_Error *`s returned by other function calls, so they must +// each be freed separately using `PJRT_Error_Destroy`. +typedef PJRT_Error* PJRT_Event_Error(PJRT_Event_Error_Args* args); + +struct PJRT_Event_Await_Args { + size_t struct_size; + void* priv; + PJRT_Event* event; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_Await_Args, event); + +// Blocks the calling thread until `event` is ready, then returns the error +// status (with `nullptr` indicating no error). The returned status should be +// freed with `PJRT_Error_Destroy`. +typedef PJRT_Error* PJRT_Event_Await(PJRT_Event_Await_Args* args); + +// A callback to be performed once an event is ready. It will be called on the +// event's error state and a pointer to an object of the caller's choice. +// Ownership of `error` is passed to the callback. The callback must destroy +// `error` via `PJRT_Error_Destroy`. The caller retains ownership of `user_arg`. +typedef void (*PJRT_Event_OnReadyCallback)(PJRT_Error* error, void* user_arg); + +struct PJRT_Event_OnReady_Args { + size_t struct_size; + void* priv; + PJRT_Event* event; + PJRT_Event_OnReadyCallback callback; + // `user_arg` allows `callback` to be called with arbitrary arguments (e.g. + // via pointers in a struct cast to void*). + void* user_arg; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Event_OnReady_Args, user_arg); + +// Registers `callback` to be called once `event` is ready, with `event`'s +// error status and a pointer to an object of the caller's choice as arguments. +typedef PJRT_Error* PJRT_Event_OnReady(PJRT_Event_OnReady_Args* args); + +// ---------------------------------- Client ----------------------------------- + +typedef struct PJRT_Client PJRT_Client; +typedef struct PJRT_Device PJRT_Device; +typedef struct PJRT_Memory PJRT_Memory; +typedef struct PJRT_DeviceDescription PJRT_DeviceDescription; +typedef struct PJRT_TopologyDescription PJRT_TopologyDescription; +typedef struct PJRT_Executable PJRT_Executable; +typedef struct PJRT_LoadedExecutable PJRT_LoadedExecutable; +typedef struct PJRT_Buffer PJRT_Buffer; + +// The caller of PJRT_Client_Create can optionally provide a key-value store +// accessible across nodes and/or processes. KV store access may be necessary to +// create some multi-node/multi-process clients. The caller can provide the two +// callbacks below to access the key-value store. + +// A callback to delete the value returned by PJRT_KeyValueGetCallback. +typedef void (*PJRT_KeyValueGetCallback_ValueDeleter)(char* value); + +struct PJRT_KeyValueGetCallback_Args { + size_t struct_size; + void* priv; + const char* key; + size_t key_size; + int timeout_in_ms; + PJRT_CallbackError* callback_error; + void* user_arg; + char* value; // out + size_t value_size; // out + // The caller needs to set a PJRT_KeyValueGetCallback_ValueDeleter to delete + // the value returned by PJRT_KeyValueGetCallback. The implementation is + // responsible for copying `value` and then calling value_deleter_callback. + PJRT_KeyValueGetCallback_ValueDeleter value_deleter_callback; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueGetCallback_Args, + value_deleter_callback); + +// Requirements for PJRT_KeyValueGetCallback implementation: (1) Thread-safe. +// (2) The caller that provides the two callbacks is responsible for avoiding +// key collisions between different users of key-value store (i.e. between +// different plugins, but not between different nodes in one plugin). (3) +// Blocking. +typedef PJRT_Error* (*PJRT_KeyValueGetCallback)( + PJRT_KeyValueGetCallback_Args* args); + +struct PJRT_KeyValuePutCallback_Args { + size_t struct_size; + void* priv; + const char* key; + size_t key_size; + // Only needs to stay alive for the duration of the PJRT_KeyValuePutCallback + // call. + const char* value; + size_t value_size; + PJRT_CallbackError* callback_error; + void* user_arg; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValuePutCallback_Args, user_arg); + +// Requirements for PJRT_KeyValuePutCallback implementation: (1) Thread-safe. +// (2) The caller that provides the two callbacks is responsible for avoiding +// key collisions between different users of key-value store (i.e. between +// different plugins, but not between different nodes in one plugin). +typedef PJRT_Error* (*PJRT_KeyValuePutCallback)( + PJRT_KeyValuePutCallback_Args* args); + +struct PJRT_Client_Create_Args { + size_t struct_size; + void* priv; + // Extra platform-specific options to create a client. + const PJRT_NamedValue* create_options; + size_t num_options; + // Key-value get/put callback provided by the caller of PJRT_Client_Create. + // PJRT client can use these callbacks to share information between + // processes/nodes. + PJRT_KeyValueGetCallback kv_get_callback; + // Will be passed to `kv_get_callback` as `user_arg` argument. + void* kv_get_user_arg; + PJRT_KeyValuePutCallback kv_put_callback; + // Will be passed to `kv_put_callback` as `user_arg` argument. + void* kv_put_user_arg; + + PJRT_Client* client; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, client); + +// Creates and initializes a new PJRT_Client and returns in `client`. +typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args); + +struct PJRT_Client_Destroy_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Destroy_Args, client); + +// Shuts down and frees `client`. `client` can be nullptr. +typedef PJRT_Error* PJRT_Client_Destroy(PJRT_Client_Destroy_Args* args); + +struct PJRT_Client_PlatformName_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; + // `platform_name` has the same lifetime as `client`. It is owned by `client`. + const char* platform_name; // out + size_t platform_name_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_PlatformName_Args, platform_name_size); + +// Returns a string that identifies the platform (e.g. "cpu", "gpu", "tpu"). +typedef PJRT_Error* PJRT_Client_PlatformName( + PJRT_Client_PlatformName_Args* args); + +struct PJRT_Client_ProcessIndex_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; + int process_index; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_ProcessIndex_Args, process_index); + +// Return the process index of this client. Always 0 in single-process +// settings. +typedef PJRT_Error* PJRT_Client_ProcessIndex( + PJRT_Client_ProcessIndex_Args* args); + +struct PJRT_Client_PlatformVersion_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; + // `platform_version` has the same lifetime as `client`. It's owned by + // `client`. + const char* platform_version; // out + size_t platform_version_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_PlatformVersion_Args, + platform_version_size); + +// Returns a string containing human-readable, platform-specific version info +// (e.g. the CUDA version on GPU or libtpu version on Cloud TPU). +typedef PJRT_Error* PJRT_Client_PlatformVersion( + PJRT_Client_PlatformVersion_Args* args); + +struct PJRT_Client_TopologyDescription_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; + // Is owned by and has the same lifetime as `client`. + PJRT_TopologyDescription* topology; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_TopologyDescription_Args, topology); + +// Returns the topology description of the runtime topology. The returned +// topology is owned by the client and should not be deleted by the caller. +typedef PJRT_Error* PJRT_Client_TopologyDescription( + PJRT_Client_TopologyDescription_Args* args); + +struct PJRT_Client_Devices_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; + PJRT_Device* const* devices; // out + size_t num_devices; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Devices_Args, num_devices); + +// Returns a list of all devices visible to the runtime, including addressable +// and non-addressable devices. +typedef PJRT_Error* PJRT_Client_Devices(PJRT_Client_Devices_Args* args); + +struct PJRT_Client_AddressableDevices_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; + PJRT_Device* const* addressable_devices; // out + size_t num_addressable_devices; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_AddressableDevices_Args, + num_addressable_devices); + +// Returns a list of devices that are addressable from the client. +// Addressable devices are those that the client can issue commands to. +// All devices are addressable in a single-process environment. +typedef PJRT_Error* PJRT_Client_AddressableDevices( + PJRT_Client_AddressableDevices_Args* args); + +struct PJRT_Client_LookupDevice_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; + int id; + // `device` has the same lifetime as `client`. It is owned by `client`. + PJRT_Device* device; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_LookupDevice_Args, device); + +// Returns a PJRT_Device* with the specified ID as returned by +// PJRT_DeviceDescription_Id. +typedef PJRT_Error* PJRT_Client_LookupDevice( + PJRT_Client_LookupDevice_Args* args); + +struct PJRT_Client_LookupAddressableDevice_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; + int local_hardware_id; + // `addressable_device` has the same lifetime as `client`. It is owned by + // `client`. + PJRT_Device* addressable_device; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_LookupAddressableDevice_Args, + addressable_device); + +// Returns an addressable PJRT_Device* with the specified ID as returned by +// PJRT_DeviceDescription_LocalHardwareId. +typedef PJRT_Error* PJRT_Client_LookupAddressableDevice( + PJRT_Client_LookupAddressableDevice_Args* args); + +struct PJRT_Client_AddressableMemories_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; + PJRT_Memory* const* addressable_memories; // out + size_t num_addressable_memories; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_AddressableMemories_Args, + num_addressable_memories); + +// Returns a list of memories that are addressable from the client. Addressable +// memories are those that the client can directly transfer data to and from. +// All memories are addressable in a single-process environment. +typedef PJRT_Error* PJRT_Client_AddressableMemories( + PJRT_Client_AddressableMemories_Args* args); + +struct PJRT_Program { + size_t struct_size; + void* priv; + // Serialized code in the specified format below. + // String is owned by the caller. + char* code; // in/out depending on usage + size_t code_size; + // Supported formats are: + // "hlo": code string takes serialized HloModuleProto. + // "hlo_with_config": code string takes serialized HloModuleProtoWithConfig. + // "mlir": code string takes MLIR module bytecode (or string). + // Ownership of `format` varies across API functions. + const char* format; + size_t format_size; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Program, format_size); + +struct PJRT_Client_Compile_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; + // Only needs to stay alive for the duration of the Compile call. + // `program->format` and `program->format_size` are owned by the caller. + const PJRT_Program* program; + // TODO(b/240560013): consider putting some of option fields in priv. + // Serialized CompileOptionsProto + // (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/compile_options.proto) + const char* compile_options; + size_t compile_options_size; + PJRT_LoadedExecutable* executable; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Compile_Args, executable); + +// Compiles a program in specified format (such as MLIR or HLO) with given +// `options`. +typedef PJRT_Error* PJRT_Client_Compile(PJRT_Client_Compile_Args* args); + +struct PJRT_Client_DefaultDeviceAssignment_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; + int num_replicas; + int num_partitions; + // Must be greater than or equal to `num_replicas * num_partitions` + size_t default_assignment_size; + // Points to an array of size `default_assignment_size`. + // This API writes `num_replicas * num_partitions` ints within that buffer. + // The caller retains ownership of this memory. + int* default_assignment; // pointer to array in; values written as out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_DefaultDeviceAssignment_Args, + default_assignment); + +typedef PJRT_Error* PJRT_Client_DefaultDeviceAssignment( + PJRT_Client_DefaultDeviceAssignment_Args* args); + +typedef enum { + // Invalid primitive type to serve as default. + PJRT_Buffer_Type_INVALID, + + // Predicates are two-state booleans. + PJRT_Buffer_Type_PRED, + + // Signed integral values of fixed width. + PJRT_Buffer_Type_S8, + PJRT_Buffer_Type_S16, + PJRT_Buffer_Type_S32, + PJRT_Buffer_Type_S64, + + // Unsigned integral values of fixed width. + PJRT_Buffer_Type_U8, + PJRT_Buffer_Type_U16, + PJRT_Buffer_Type_U32, + PJRT_Buffer_Type_U64, + + // Floating-point values of fixed width. + PJRT_Buffer_Type_F16, + PJRT_Buffer_Type_F32, + PJRT_Buffer_Type_F64, + + // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit + // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent + // and 7 bits for the mantissa. + PJRT_Buffer_Type_BF16, + + // Complex values of fixed width. + // + // Paired F32 (real, imag), as in std::complex. + PJRT_Buffer_Type_C64, + // Paired F64 (real, imag), as in std::complex. + PJRT_Buffer_Type_C128, + + // Truncated 8 bit floating-point formats. + PJRT_Buffer_Type_F8E5M2, + PJRT_Buffer_Type_F8E4M3FN, + PJRT_Buffer_Type_F8E4M3B11FNUZ, + PJRT_Buffer_Type_F8E5M2FNUZ, + PJRT_Buffer_Type_F8E4M3FNUZ, + + // 4-bit integer types + PJRT_Buffer_Type_S4, + PJRT_Buffer_Type_U4, +} PJRT_Buffer_Type; + +typedef enum { + // The runtime may not hold references to `data` after the call to + // `PJRT_Client_BufferFromHostBuffer` completes. The caller promises that + // `data` is immutable and will not be freed only for the duration of the + // PJRT_Client_BufferFromHostBuffer call. + PJRT_HostBufferSemantics_kImmutableOnlyDuringCall, + + // The runtime may hold onto `data` after the call to + // `PJRT_Client_BufferFromHostBuffer` + // returns while the runtime completes a transfer to the device. The caller + // promises not to mutate or free `data` until the transfer completes, at + // which point `done_with_host_buffer` will be triggered. + PJRT_HostBufferSemantics_kImmutableUntilTransferCompletes, + + // The PjRtBuffer may alias `data` internally and the runtime may use the + // `data` contents as long as the buffer is alive. The caller promises to + // keep `data` alive and not to mutate its contents as long as the buffer is + // alive; to notify the caller that the buffer may be freed, the runtime + // will call `done_with_host_buffer` when the PjRtBuffer is freed. + PJRT_HostBufferSemantics_kZeroCopy, +} PJRT_HostBufferSemantics; + +typedef enum { + PJRT_Buffer_MemoryLayout_Type_Tiled = 0, + PJRT_Buffer_MemoryLayout_Type_Strides, +} PJRT_Buffer_MemoryLayout_Type; + +struct PJRT_Buffer_MemoryLayout_Tiled { + size_t struct_size; + void* priv; + // A map from physical dimension numbers to logical dimension numbers. + // The first element is the most minor physical dimension (fastest varying + // index) and the last the most major (slowest varying index). The contents of + // the vector are the indices of the *logical* dimensions in the shape. Must + // be the same size as the number of dimensions of the buffer. + const int64_t* minor_to_major; + size_t minor_to_major_size; + // A concatenated list of tile dimensions. + const int64_t* tile_dims; + // The list of tile dimension sizes. The size of this list is `num_tiles`. + const size_t* tile_dim_sizes; + size_t num_tiles; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_MemoryLayout_Tiled, num_tiles); + +struct PJRT_Buffer_MemoryLayout_Strides { + size_t struct_size; + void* priv; + // Number of bytes to traverse per dimension. Must be the same size as + // the number of dimensions of the data. Caution: `byte_strides` are allowed + // to be negative, in which case data may need to point to the interior of + // the buffer, not necessarily its start. + const int64_t* byte_strides; + size_t num_byte_strides; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_MemoryLayout_Strides, num_byte_strides); + +// Describe the memory layout. It can be (1) a list of minor-to-major order and +// optional tilings (each tile is a list of dimensions), or (2) a list of +// strides. +struct PJRT_Buffer_MemoryLayout { + size_t struct_size; + void* priv; + union { + PJRT_Buffer_MemoryLayout_Tiled tiled; + PJRT_Buffer_MemoryLayout_Strides strides; + }; + PJRT_Buffer_MemoryLayout_Type type; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_MemoryLayout, type); + +struct PJRT_Client_BufferFromHostBuffer_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; + // Pointer to the host buffer + const void* data; + // The type of the `data`, and the type of the resulting output `buffer` + PJRT_Buffer_Type type; + // The array dimensions of `data`. + const int64_t* dims; + size_t num_dims; + + // Number of bytes to traverse per dimension of the input data. Must be the + // same size as `dims`, or empty. If empty, the array is assumed to have a + // dense layout with dimensions in major-to-minor order + // Caution: `byte_strides` are allowed to be negative, in which case `data` + // may need to point to the interior of the buffer, not necessarily its start. + const int64_t* byte_strides; + size_t num_byte_strides; + + PJRT_HostBufferSemantics host_buffer_semantics; + + // Device to copy host data to. + PJRT_Device* device; + + // If nullptr, host data will be copied to `device`, otherwise we copy data to + // `memory`. + PJRT_Memory* memory; + + // The caller is responsible to keep the data (tiled or strides) in the + // device_layout alive during the call. If nullptr, the device layout is + // assumed to be a dense layout with dimensions in major-to-minor order. + PJRT_Buffer_MemoryLayout* device_layout; + + // Event indicating when it's safe to free `data`. The caller is responsible + // for calling PJRT_Event_Destroy. + PJRT_Event* done_with_host_buffer; // out + + // Output device buffer. The caller is responsible for calling + // PJRT_Buffer_Destroy. + PJRT_Buffer* buffer; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_BufferFromHostBuffer_Args, buffer); + +// Asynchronously copies a buffer stored on host to device memory. +typedef PJRT_Error* PJRT_Client_BufferFromHostBuffer( + PJRT_Client_BufferFromHostBuffer_Args* args); + +struct PJRT_Client_CreateViewOfDeviceBuffer_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; + // A pointer to a non-owned device buffer. A PJRT_Buffer that is a non-owned + // view of this device buffer will be created. + void* device_buffer_ptr; + const int64_t* dims; + size_t num_dims; + PJRT_Buffer_Type element_type; + PJRT_Buffer_MemoryLayout* layout; + // The device that `device_buffer_ptr` is on. + PJRT_Device* device; + // A callback to be performed when the PJRT_Buffer is done with the on-device + // buffer. This callback is optional and can be a nullptr. + void (*on_delete_callback)(void* device_buffer_ptr, void* user_arg); + // `on_delete_callback_arg` will be passed to `on_delete_callback` as + // `user_arg` argument. + void* on_delete_callback_arg; + // A platform-specific stream handle that should contain the work or events + // needed to materialize the on-device buffer. It is optional and can be + // casted from a nullptr. PJRT_Client_CreateViewOfDeviceBuffer_Args will + // append an event to `stream` that indicates when the returned buffer is + // ready to use. This is intended to support dlpack on GPU and is not expected + // to be supported on all hardware platforms. + intptr_t stream; + PJRT_Buffer* buffer; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_CreateViewOfDeviceBuffer_Args, buffer); + +// Creates a PJRT buffer that is a non-owned view of an on-device buffer +// (typically allocated by another library). The buffer may be mutated, +// for example, if the buffer is donated to an Execute operation. This method is +// not required on all hardware platforms. +typedef PJRT_Error* PJRT_Client_CreateViewOfDeviceBuffer( + PJRT_Client_CreateViewOfDeviceBuffer_Args* args); + +// -------------------------- Device Descriptions ------------------------------ + +// Device descriptions may be associated with an actual device +// (via PJRT_Device_GetDescription), but they can also be used to describe a +// device that isn't currently available to the plugin. This is useful for +// compiling executables without hardware available, which can then be +// serialized and written somewhere durable, and then loaded and run on actual +// hardware later. + +struct PJRT_DeviceDescription_Id_Args { + size_t struct_size; + void* priv; + PJRT_DeviceDescription* device_description; + int id; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_Id_Args, id); + +// The ID of this device. IDs are unique among devices of this type +// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all +// hosts' devices. +typedef PJRT_Error* PJRT_DeviceDescription_Id( + PJRT_DeviceDescription_Id_Args* args); + +struct PJRT_DeviceDescription_ProcessIndex_Args { + size_t struct_size; + void* priv; + PJRT_DeviceDescription* device_description; + int process_index; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_ProcessIndex_Args, + process_index); + +// The index of the process that this device belongs to, i.e. is addressable +// from. This is not always identical to PJRT_Client_ProcessIndex in a +// multi-process setting, where each client can see devices from all +// processes, but only a subset of them are addressable and have the same +// process_index as the client. +typedef PJRT_Error* PJRT_DeviceDescription_ProcessIndex( + PJRT_DeviceDescription_ProcessIndex_Args* args); + +struct PJRT_DeviceDescription_Attributes_Args { + size_t struct_size; + void* priv; + PJRT_DeviceDescription* device_description; + size_t num_attributes; // out + const PJRT_NamedValue* attributes; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_Attributes_Args, attributes); + +// Returns an array of device specific attributes with attribute name, value +// and value type. +typedef PJRT_Error* PJRT_DeviceDescription_Attributes( + PJRT_DeviceDescription_Attributes_Args* args); + +struct PJRT_DeviceDescription_Kind_Args { + size_t struct_size; + void* priv; + PJRT_DeviceDescription* device_description; + // `device_kind` string is owned by `device` and has same lifetime as + // `device`. + const char* device_kind; // out + size_t device_kind_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_Kind_Args, device_kind_size); + +// A vendor-dependent string that uniquely identifies the kind of device, +// e.g., "Tesla V100-SXM2-16GB". +typedef PJRT_Error* PJRT_DeviceDescription_Kind( + PJRT_DeviceDescription_Kind_Args* args); + +struct PJRT_DeviceDescription_DebugString_Args { + size_t struct_size; + void* priv; + PJRT_DeviceDescription* device_description; + const char* debug_string; // out + size_t debug_string_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_DebugString_Args, + debug_string_size); + +// Debug string suitable for logging when errors occur. Should be verbose +// enough to describe the current device unambiguously. +typedef PJRT_Error* PJRT_DeviceDescription_DebugString( + PJRT_DeviceDescription_DebugString_Args* args); + +struct PJRT_DeviceDescription_ToString_Args { + size_t struct_size; + void* priv; + PJRT_DeviceDescription* device_description; + const char* to_string; // out + size_t to_string_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_ToString_Args, to_string_size); + +// Debug string suitable for reading by end users, should be reasonably terse, +// for example: "CpuDevice(id=0)". +typedef PJRT_Error* PJRT_DeviceDescription_ToString( + PJRT_DeviceDescription_ToString_Args* args); + +// --------------------------------- Devices ----------------------------------- + +struct PJRT_Device_GetDescription_Args { + size_t struct_size; + void* priv; + PJRT_Device* device; + PJRT_DeviceDescription* device_description; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_GetDescription_Args, device_description); + +// Fetch the DeviceDescription associated with this device. +typedef PJRT_Error* PJRT_Device_GetDescription( + PJRT_Device_GetDescription_Args* args); + +struct PJRT_Device_IsAddressable_Args { + size_t struct_size; + void* priv; + PJRT_Device* device; + bool is_addressable; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_IsAddressable_Args, is_addressable); + +// Whether client can issue command to this device. +typedef PJRT_Error* PJRT_Device_IsAddressable( + PJRT_Device_IsAddressable_Args* args); + +struct PJRT_Device_LocalHardwareId_Args { + size_t struct_size; + void* priv; + PJRT_Device* device; + int local_hardware_id; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_LocalHardwareId_Args, local_hardware_id); + +// Opaque hardware ID, e.g., the CUDA device number. In general, not guaranteed +// to be dense, and -1 if undefined. +typedef PJRT_Error* PJRT_Device_LocalHardwareId( + PJRT_Device_LocalHardwareId_Args* args); + +struct PJRT_Device_AddressableMemories_Args { + size_t struct_size; + void* priv; + PJRT_Device* device; + // Has the lifetime of `device`. + PJRT_Memory* const* memories; // out + size_t num_memories; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_AddressableMemories_Args, memories); + +// Returns the memories that a device can address. +typedef PJRT_Error* PJRT_Device_AddressableMemories( + PJRT_Device_AddressableMemories_Args* args); + +struct PJRT_Device_DefaultMemory_Args { + size_t struct_size; + void* priv; + PJRT_Device* device; + // `memory` has the same lifetime as `device`. + PJRT_Memory* memory; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_DefaultMemory_Args, memory); + +// Returns the default memory of a device, i.e. which memory data processed by +// this device should be stored in by default. +typedef PJRT_Error* PJRT_Device_DefaultMemory( + PJRT_Device_DefaultMemory_Args* args); + +struct PJRT_Device_MemoryStats_Args { + size_t struct_size; + void* priv; + PJRT_Device* device; + + // Number of bytes in use. + int64_t bytes_in_use; // out + + // The peak bytes in use. + int64_t peak_bytes_in_use; // out + bool peak_bytes_in_use_is_set; // out + // Number of allocations. + int64_t num_allocs; // out + bool num_allocs_is_set; // out + // The largest single allocation seen. + int64_t largest_alloc_size; // out + bool largest_alloc_size_is_set; // out + // The upper limit of user-allocatable device memory in bytes. + int64_t bytes_limit; // out + bool bytes_limit_is_set; // out + + // Number of bytes reserved. + int64_t bytes_reserved; // out + bool bytes_reserved_is_set; // out + // The peak number of bytes reserved. + int64_t peak_bytes_reserved; // out + bool peak_bytes_reserved_is_set; // out + // The upper limit on the number bytes of reservable memory. + int64_t bytes_reservable_limit; // out + bool bytes_reservable_limit_is_set; // out + + // Largest free block size in bytes. + int64_t largest_free_block_bytes; // out + bool largest_free_block_bytes_is_set; // out + + // Number of bytes of memory held by the allocator. This may be higher than + // bytes_in_use if the allocator holds a pool of memory (e.g. BFCAllocator). + int64_t pool_bytes; // out + bool pool_bytes_is_set; // out + int64_t peak_pool_bytes; // out + bool peak_pool_bytes_is_set; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_MemoryStats_Args, peak_pool_bytes_is_set); + +// Device memory/allocator statistics. All returned stats except `bytes_in_use` +// are optional and may not be returned by all platforms. Implementations may +// also return PJRT_Error_Code_UNIMPLEMENTED. Intended for diagnostic purposes. +typedef PJRT_Error* PJRT_Device_MemoryStats(PJRT_Device_MemoryStats_Args* args); + +//-------------------------------- Memory -------------------------------------- + +struct PJRT_Memory_Id_Args { + size_t struct_size; + void* priv; + PJRT_Memory* memory; + int id; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Memory_Id_Args, id); + +// The ID of this memory. IDs are unique among memories of this type. +typedef PJRT_Error* PJRT_Memory_Id(PJRT_Memory_Id_Args* args); + +struct PJRT_Memory_Kind_Args { + size_t struct_size; + void* priv; + PJRT_Memory* memory; + // `memory_kind` has same lifetime as `memory`. + const char* memory_kind; // out + size_t memory_kind_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Memory_Kind_Args, memory_kind_size); + +// A platform-dependent string that uniquely identifies the kind of the memory. +typedef PJRT_Error* PJRT_Memory_Kind(PJRT_Memory_Kind_Args* args); + +struct PJRT_Memory_DebugString_Args { + size_t struct_size; + void* priv; + PJRT_Memory* memory; + const char* debug_string; // out + size_t debug_string_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Memory_DebugString_Args, debug_string_size); + +// Debug string suitable for logging when errors occur. Should be verbose +// enough to describe the current memory unambiguously. +typedef PJRT_Error* PJRT_Memory_DebugString(PJRT_Memory_DebugString_Args* args); + +struct PJRT_Memory_ToString_Args { + size_t struct_size; + void* priv; + PJRT_Memory* memory; + const char* to_string; // out + size_t to_string_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Memory_ToString_Args, to_string_size); + +// Debug string suitable for reading by end users, should be reasonably terse. +typedef PJRT_Error* PJRT_Memory_ToString(PJRT_Memory_ToString_Args* args); + +struct PJRT_Memory_AddressableByDevices_Args { + size_t struct_size; + void* priv; + PJRT_Memory* memory; + PJRT_Device* const* devices; // out + size_t num_devices; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Memory_AddressableByDevices_Args, num_devices); + +// Returns the devices that can address this memory. +typedef PJRT_Error* PJRT_Memory_AddressableByDevices( + PJRT_Memory_AddressableByDevices_Args* args); + +// ------------------------------- Executables --------------------------------- + +struct PJRT_Executable_Destroy_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Destroy_Args, executable); + +// Frees `executable`. `executable` can be nullptr. +typedef PJRT_Error* PJRT_Executable_Destroy(PJRT_Executable_Destroy_Args* args); + +struct PJRT_LoadedExecutable_Destroy_Args { + size_t struct_size; + void* priv; + PJRT_LoadedExecutable* executable; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Destroy_Args, executable); + +// Frees `executable` and deletes the underlying runtime object as if +// `PJRT_LoadedExecutable_Delete` were called. `executable` can be nullptr. +typedef PJRT_Error* PJRT_LoadedExecutable_Destroy( + PJRT_LoadedExecutable_Destroy_Args* args); + +struct PJRT_LoadedExecutable_GetExecutable_Args { + size_t struct_size; + void* priv; + PJRT_LoadedExecutable* loaded_executable; + PJRT_Executable* executable; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_GetExecutable_Args, executable); + +// Constructs a PJRT_Executable from a PJRT_LoadedExecutable. The returned +// executable should be freed by the caller with PJRT_Executable_Destroy. +typedef PJRT_Error* PJRT_LoadedExecutable_GetExecutable( + PJRT_LoadedExecutable_GetExecutable_Args* args); + +struct PJRT_Executable_Name_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; + // `executable_name` has the same lifetime as `executable`. It is owned by + // `executable`. + const char* executable_name; // out + size_t executable_name_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Name_Args, executable_name_size); + +// Returns a string that identifies the executable. +typedef PJRT_Error* PJRT_Executable_Name(PJRT_Executable_Name_Args* args); + +// TODO(b/269178731): Revisit whether num_replicas is needed. +struct PJRT_Executable_NumReplicas_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; + size_t num_replicas; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_NumReplicas_Args, num_replicas); + +// Returns the number of replicas of the executable. +typedef PJRT_Error* PJRT_Executable_NumReplicas( + PJRT_Executable_NumReplicas_Args* args); + +struct PJRT_Executable_NumPartitions_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; + size_t num_partitions; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_NumPartitions_Args, num_partitions); + +// Returns the number of partitions of the executable. +typedef PJRT_Error* PJRT_Executable_NumPartitions( + PJRT_Executable_NumPartitions_Args* args); + +struct PJRT_LoadedExecutable_AddressableDevices_Args { + size_t struct_size; + void* priv; + PJRT_LoadedExecutable* executable; + PJRT_Device* const* addressable_devices; // out + size_t num_addressable_devices; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_AddressableDevices_Args, + num_addressable_devices); + +// Returns a list of devices this executable will run on. +typedef PJRT_Error* PJRT_LoadedExecutable_AddressableDevices( + PJRT_LoadedExecutable_AddressableDevices_Args* args); + +struct PJRT_Executable_OptimizedProgram_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; + PJRT_Program* program; // out, but read below +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_OptimizedProgram_Args, program); + +// Retrieves the optimized program for a given PJRT_Executable (SPMD). +// The caller should populate `program->format` and `format_size`. +// +// The implementation will set `program->format` and `program->format_size` +// to inform callers of the format of the optimized program returned. +// These members are owned by the implementation. +// +// If called with nullptr as `program->code`, `PJRT_Executable_OptimizedProgram` +// will populate `program->code_size` as an output indicating the number of +// bytes the string `program->code` requires. +// +// If `program->code` is not null, `PJRT_Executable_OptimizedProgram` will fill +// the buffer pointed to by `program->code` with the serialization of the +// optimized HLO program. `program->code` must point to a client-owned buffer of +// size >= `program->code_size`, which must be at large enough to hold the +// serialization of the optimized program. +// +// Callers should generally call this function twice with the same `args`. +// In the first call, `program->code` must be nullptr. This call will populate +// `program->code_size`. Clients should then allocate a buffer `code_buff` of at +// least `code_size` bytes. Before the second call, callers should set +// `program->code = code_buff`. The second call will then write the serialized +// program to `code_buff`. +typedef PJRT_Error* PJRT_Executable_OptimizedProgram( + PJRT_Executable_OptimizedProgram_Args* args); + +struct PJRT_LoadedExecutable_Delete_Args { + size_t struct_size; + void* priv; + PJRT_LoadedExecutable* executable; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Delete_Args, executable); + +// Drops `executable`'s reference to the internal runtime object and +// associated resources, without freeing the `executable` object itself. +// `executable` can only be used with PJRT_LoadedExecutable_IsDeleted and +// PJRT_LoadedExecutable_Destroy after calling this method. The internal runtime +// executable will be freed after the last execution completes. +typedef PJRT_Error* PJRT_LoadedExecutable_Delete( + PJRT_LoadedExecutable_Delete_Args* args); + +struct PJRT_LoadedExecutable_IsDeleted_Args { + size_t struct_size; + void* priv; + PJRT_LoadedExecutable* executable; + bool is_deleted; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_IsDeleted_Args, is_deleted); + +// True if and only if PJRT_LoadedExecutable_Delete has previously been called. +typedef PJRT_Error* PJRT_LoadedExecutable_IsDeleted( + PJRT_LoadedExecutable_IsDeleted_Args* args); + +typedef struct PJRT_Chunk { + void* data; + size_t size; + void (*deleter)(void* data, void* deleter_arg); + // `deleter_arg` will be passed to `deleter` as `deleter_arg` argument. + void* deleter_arg; +} PJRT_Chunk; + +// TODO(b/263390934) implement C API that calls `AddChunk` and other +// `xla::CopyToDeviceStream`. +typedef struct PJRT_CopyToDeviceStream PJRT_CopyToDeviceStream; + +struct PJRT_TransferMetadata; + +// Returns PJRT_Error* created by PJRT_CallbackError in case of error. +// Otherwise, returns nullptr. The callback must call +// `chunk->deleter(chunk->data, chunk->deleter_arg)` when it's finished with +// `chunk`. +typedef PJRT_Error* (*PJRT_SendCallback)(PJRT_Chunk* chunk, + PJRT_CallbackError* callback_error, + size_t total_size_in_bytes, bool done, + void* user_arg); +// The callback takes the ownership of the stream object. The callback must call +// `PJRT_CopyToDeviceStream_Destroy` when it is done with the stream. +typedef void (*PJRT_RecvCallback)(PJRT_CopyToDeviceStream* stream, + void* user_arg); + +struct PJRT_SendCallbackInfo { + // Used to associate this callback with the correct send op. + int64_t channel_id; + // Will be passed to `send_callback` as `user_arg` argument. + void* user_arg; + PJRT_SendCallback send_callback; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_SendCallbackInfo, send_callback); + +struct PJRT_RecvCallbackInfo { + // Used to associate this callback with the correct recv op. + int64_t channel_id; + // Will be passed to `recv_callback` as `user_arg` argument. + void* user_arg; + PJRT_RecvCallback recv_callback; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_RecvCallbackInfo, recv_callback); + +struct PJRT_ExecuteOptions { + size_t struct_size; + void* priv; + // Callbacks for when send/recv ops are executed. The outer lists correspond + // to each device returned by `PJRT_Executable_AddressableDevices` for + // `executable` (i.e. they will have length `num_devices`). Each inner list + // contains callback info for each send/recv op in `executable`; the order + // doesn't matter as the channel IDs are used instead. The callbacks can be + // stateful and the user code is responsible for managing state. The callback + // functions must outlive the execution (but not the info structs or lists). + PJRT_SendCallbackInfo** send_callbacks; + PJRT_RecvCallbackInfo** recv_callbacks; + size_t num_send_ops; + size_t num_recv_ops; + // If non-zero, identifies this execution as part of a potentially + // multi-device launch. This can be used to detect scheduling errors, e.g. if + // multi-host programs are launched in different orders on different hosts, + // the launch IDs may be used by the runtime to detect the mismatch. + int launch_id; + // A list of indices denoting the input buffers that should not be donated. + // An input buffer may be non-donable, for example, if it is referenced more + // than once. Since such runtime information is not available at compile time, + // the compiler might mark the input as `may-alias`, which could lead PjRt to + // donate the input buffer when it should not. By defining this list of + // indices, a higher-level PJRT caller can instruct PJRT client not to donate + // specific input buffers. The caller needs to make sure to keep it alive + // during the call. + const int64_t* non_donatable_input_indices; + size_t num_non_donatable_input_indices; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_ExecuteOptions, launch_id); + +struct PJRT_LoadedExecutable_Execute_Args { + size_t struct_size; + void* priv; + PJRT_LoadedExecutable* executable; + // Only needs to stay alive for the duration of the Execute call. + PJRT_ExecuteOptions* options; + // Execution input of size [`num_devices`, `num_args`]. + PJRT_Buffer* const* const* argument_lists; + size_t num_devices; + size_t num_args; + // Execution output of size [`num_devices`, num_outputs`], where `num_outputs` + // is the number of outputs returned by this executable per device. Both the + // outer (`PJRT_Buffer***`) and inner lists (`PJRT_Buffer**`) must be + // allocated and deallocated by the caller. PJRT_Buffer_Destroy must be called + // on the output PJRT_Buffer*. + PJRT_Buffer** const* output_lists; // in/out + // If `device_complete_events` isn't nullptr, `device_complete_events` needs + // to be the same length as `output_lists` (i.e. of length `num_devices`), and + // each `PJRT_Event` will become ready once the corresponding device execution + // is complete. If Execute returns an error, then `device_complete_events` + // will not be populated. The caller is responsible for calling + // PJRT_Event_Destroy on the returned PJRT_Event*s. + PJRT_Event** device_complete_events; // in/out + // The device to execute on. If nullptr, will execute on the device(s) + // specified at compile time. If set, must be an addressable device, and + // `num_devices` should be 1 with `argument_lists` only containing arguments + // for `execute_device`. Can be set with a multi-device executable to launch + // just on this device. In this case, it's the responsibility of the caller to + // make sure the executable is launched on all participating devices specified + // at compile time. Setting this field may not be supported on all platforms + // or executables. + PJRT_Device* execute_device; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Execute_Args, execute_device); + +// Executes on devices addressable by the client. +typedef PJRT_Error* PJRT_LoadedExecutable_Execute( + PJRT_LoadedExecutable_Execute_Args* args); + +struct PJRT_Executable_NumOutputs_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; + size_t num_outputs; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_NumOutputs_Args, num_outputs); + +// Gets the number of outputs per device produced by `executable`. +typedef PJRT_Error* PJRT_Executable_NumOutputs( + PJRT_Executable_NumOutputs_Args* args); + +struct PJRT_Executable_SizeOfGeneratedCodeInBytes_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; + int64_t size_in_bytes; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_SizeOfGeneratedCodeInBytes_Args, + size_in_bytes); // last field in the struct + +typedef PJRT_Error* PJRT_Executable_SizeOfGeneratedCodeInBytes( + PJRT_Executable_SizeOfGeneratedCodeInBytes_Args* args); + +struct PJRT_Executable_Fingerprint_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; + // Has the lifetime of `executable` + const char* executable_fingerprint; // out + size_t executable_fingerprint_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Fingerprint_Args, + executable_fingerprint_size); + +// A unique fingerprint for `executable`. Two executables that were produced by +// compiling with identical inputs (same program, compile options, compiler +// version, etc.) should have the same fingerprint. May not be implemented by +// all platforms. +typedef PJRT_Error* PJRT_Executable_Fingerprint( + PJRT_Executable_Fingerprint_Args* args); + +struct PJRT_Executable_GetCostAnalysis_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; + size_t num_properties; // out + // `properties` and any embedded data are owned by and have the same lifetime + // as `executable`. + const PJRT_NamedValue* properties; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_GetCostAnalysis_Args, properties); + +// Get the cost properties for the executable. Different platforms may return +// different properties; for example, some platforms may return the number of +// operations, or memory size of the input/output of the executable, based on +// program analysis. +typedef PJRT_Error* PJRT_Executable_GetCostAnalysis( + PJRT_Executable_GetCostAnalysis_Args* args); + +struct PJRT_Executable_GetCompiledMemoryStats_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; + + // Mirrors xla::CompiledMemoryStats. + int64_t generated_code_size_in_bytes; // out + int64_t argument_size_in_bytes; // out + int64_t output_size_in_bytes; // out + // How much argument is reused for output. + int64_t alias_size_in_bytes; // out + int64_t temp_size_in_bytes; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_GetCompiledMemoryStats_Args, + temp_size_in_bytes); + +// Return memory stats that allow callers to estimate device memory usage +// when running this executable. +typedef PJRT_Error* PJRT_Executable_GetCompiledMemoryStats( + PJRT_Executable_GetCompiledMemoryStats_Args* args); + +struct PJRT_Executable_OutputElementTypes_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; + PJRT_Buffer_Type* output_types; // out + size_t num_output_types; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_OutputElementTypes_Args, + num_output_types); + +// Returns a list of element types for outputs. +typedef PJRT_Error* PJRT_Executable_OutputElementTypes( + PJRT_Executable_OutputElementTypes_Args* args); + +struct PJRT_Executable_OutputDimensions_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; + size_t num_outputs; + // Has length: sum of all elements in the list `dim_sizes`. + const int64_t* dims; // out + // Has length `num_outputs`. + const size_t* dim_sizes; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_OutputDimensions_Args, dim_sizes); + +// Returns a list of dimensions for outputs. Each output has an array shape, +// which is represented by a list of dimensions. The array shapes of all outputs +// are concatenated into a single list of dimensions. +typedef PJRT_Error* PJRT_Executable_OutputDimensions( + PJRT_Executable_OutputDimensions_Args* args); + +struct PJRT_Executable_OutputMemoryKinds_Args { + size_t struct_size; + void* priv; + PJRT_Executable* executable; + size_t num_outputs; + // Has length `num_outputs`. + const char* const* memory_kinds; // out + // Has length `num_outputs`. + const size_t* memory_kind_sizes; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_OutputMemoryKinds_Args, + memory_kind_sizes); + +// Returns a list of memory kind strings for outputs. +typedef PJRT_Error* PJRT_Executable_OutputMemoryKinds( + PJRT_Executable_OutputMemoryKinds_Args* args); + +typedef struct PJRT_SerializedExecutable PJRT_SerializedExecutable; + +struct PJRT_Executable_Serialize_Args { + size_t struct_size; + void* priv; + const PJRT_Executable* executable; + + // Lives only as long as serialized_executable + const char* serialized_bytes; // out + size_t serialized_bytes_size; // out + + PJRT_SerializedExecutable* serialized_executable; // backs serialized_bytes. + // cleanup fn must be called to free the backing memory for serialized_bytes. + // Should only be called once on serialized_executable. + void (*serialized_executable_deleter)( + PJRT_SerializedExecutable* exec); // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Serialize_Args, + serialized_executable_deleter); + +// Returns a platform-specific serialization of `executable`. The serialization +// is not guaranteed to be stable over time. +typedef PJRT_Error* PJRT_Executable_Serialize( + PJRT_Executable_Serialize_Args* args); + +struct PJRT_Executable_DeserializeAndLoad_Args { + size_t struct_size; + void* priv; + PJRT_Client* client; + const char* serialized_executable; + size_t serialized_executable_size; + PJRT_LoadedExecutable* loaded_executable; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_DeserializeAndLoad_Args, + loaded_executable); + +// Deserializes an executable serialized by `PJRT_Executable_Serialize`. +// `serialized_executable` must have been produced by the same platform and +// library version as this one. +typedef PJRT_Error* PJRT_Executable_DeserializeAndLoad( + PJRT_Executable_DeserializeAndLoad_Args* args); + +struct PJRT_LoadedExecutable_Fingerprint_Args { + size_t struct_size; + void* priv; + PJRT_LoadedExecutable* executable; + // Has the lifetime of `executable` + const char* executable_fingerprint; // out + size_t executable_fingerprint_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Fingerprint_Args, + executable_fingerprint_size); +// DEPRECATED. Will be removed in PJRT version 2.0. Please use +// PJRT_Executable_Fingerprint instead. A unique fingerprint for `executable`. +// Two executables that were produced by compiling with identical inputs (same +// program, compile options, compiler version, etc.) should have the same +// fingerprint. May not be implemented by all platforms. +typedef PJRT_Error* PJRT_LoadedExecutable_Fingerprint( + PJRT_LoadedExecutable_Fingerprint_Args* args); + +// ---------------------------------- Buffers ---------------------------------- + +struct PJRT_Buffer_Destroy_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_Destroy_Args, buffer); + +// Deletes the underlying runtime objects as if 'PJRT_Buffer_Delete' were +// called and frees `buffer`. `buffer` can be nullptr. +typedef PJRT_Error* PJRT_Buffer_Destroy(PJRT_Buffer_Destroy_Args* args); + +struct PJRT_Buffer_ElementType_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + PJRT_Buffer_Type type; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_ElementType_Args, type); + +// Returns the type of the array elements of a buffer. +typedef PJRT_Error* PJRT_Buffer_ElementType(PJRT_Buffer_ElementType_Args* args); + +struct PJRT_Buffer_Dimensions_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + // Has the lifetime of `buffer` and length `num_dims`. + const int64_t* dims; // out + size_t num_dims; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_Dimensions_Args, num_dims); + +// Returns the array shape of `buffer`, i.e. the size of each dimension. +typedef PJRT_Error* PJRT_Buffer_Dimensions(PJRT_Buffer_Dimensions_Args* args); + +struct PJRT_Buffer_UnpaddedDimensions_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + // Has the lifetime of `buffer` and length `num_dims`. + const int64_t* unpadded_dims; // out + size_t num_dims; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_UnpaddedDimensions_Args, num_dims); + +// Returns the unpadded array shape of `buffer`. This usually is equivalent to +// PJRT_Buffer_Dimensions, but for implementations that support +// dynamically-sized dimensions via padding to a fixed size, any dynamic +// dimensions may have a smaller unpadded size than the padded size reported by +// PJRT_Buffer_Dimensions. ("Dynamic" dimensions are those whose length is +// only known at runtime, vs. "static" dimensions whose size is fixed at compile +// time.) +typedef PJRT_Error* PJRT_Buffer_UnpaddedDimensions( + PJRT_Buffer_UnpaddedDimensions_Args* args); + +struct PJRT_Buffer_DynamicDimensionIndices_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + // Has the lifetime of `buffer` and length `num_dynamic_dims`. + const size_t* dynamic_dim_indices; // out + size_t num_dynamic_dims; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_DynamicDimensionIndices_Args, + num_dynamic_dims); + +// Returns the indices of dynamically-sized dimensions, or an empty list if all +// dimensions are static. ("Dynamic" dimensions are those whose length is +// only known at runtime, vs. "static" dimensions whose size is fixed at compile +// time.) +typedef PJRT_Error* PJRT_Buffer_DynamicDimensionIndices( + PJRT_Buffer_DynamicDimensionIndices_Args* args); + +struct PJRT_Buffer_GetMemoryLayout_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + // Layout data is owned by and has the lifetime of `buffer`. + PJRT_Buffer_MemoryLayout layout; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_GetMemoryLayout_Args, layout); + +// Returns the memory layout of the data in this buffer. +typedef PJRT_Error* PJRT_Buffer_GetMemoryLayout( + PJRT_Buffer_GetMemoryLayout_Args* args); + +struct PJRT_Buffer_ToHostBuffer_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* src; + + // The caller can specify an optional host layout. If nullptr, the layout of + // the src buffer will be used. The caller is responsible to keep the data + // (tiled or strides) in the host_layout alive during the call. + PJRT_Buffer_MemoryLayout* host_layout; + // `dst` can be nullptr to query required size which will be set into + // `dst_size`. + void* dst; // in/out + // Size of `dst` in bytes. If `dst` is nullptr, then `dst_size` is set to the + // size needed. Otherwise, `dst_size` must be greater than or equal to the + // needed size. + size_t dst_size; // in/out + + // Event that signals when the copy has completed. + PJRT_Event* event; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_ToHostBuffer_Args, event); + +// Asynchronously copies the buffer's value into a preallocated host buffer. +typedef PJRT_Error* PJRT_Buffer_ToHostBuffer( + PJRT_Buffer_ToHostBuffer_Args* args); + +struct PJRT_Buffer_OnDeviceSizeInBytes_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + size_t on_device_size_in_bytes; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_OnDeviceSizeInBytes_Args, + on_device_size_in_bytes); + +// Gets the number of bytes of the buffer storage on the device +typedef PJRT_Error* PJRT_Buffer_OnDeviceSizeInBytes( + PJRT_Buffer_OnDeviceSizeInBytes_Args* args); + +struct PJRT_Buffer_Delete_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_Delete_Args, buffer); + +// Drop the buffer's reference to its associated device memory, without freeing +// the `buffer` object itself. `buffer` can only be used with +// PJRT_Buffer_IsDeleted and PJRT_Buffer_Destroy after calling this method. The +// device memory will be freed when all async operations using the buffer have +// completed, according to the allocation semantics of the underlying platform. +typedef PJRT_Error* PJRT_Buffer_Delete(PJRT_Buffer_Delete_Args* args); + +struct PJRT_Buffer_IsDeleted_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + bool is_deleted; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_IsDeleted_Args, is_deleted); + +// True if and only if PJRT_Buffer_Delete has previously been called. +typedef PJRT_Error* PJRT_Buffer_IsDeleted(PJRT_Buffer_IsDeleted_Args* args); + +struct PJRT_Buffer_CopyToDevice_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + PJRT_Device* dst_device; + PJRT_Buffer* dst_buffer; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_CopyToDevice_Args, dst_buffer); + +// Copies the buffer to device `dst_device` within the same client. Caller is +// responsible for freeing returned `dst_buffer` with PJRT_Buffer_Destroy. +// Returns an error if the buffer is already on `dst_device`. +typedef PJRT_Error* PJRT_Buffer_CopyToDevice( + PJRT_Buffer_CopyToDevice_Args* args); + +struct PJRT_Buffer_CopyToMemory_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + PJRT_Memory* dst_memory; + PJRT_Buffer* dst_buffer; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_CopyToMemory_Args, dst_buffer); + +// Copies the buffer to memory `dst_memory` within the same client. Caller is +// responsible for freeing returned `dst_buffer` with PJRT_Buffer_Destroy. +// Returns an error if the buffer is already on `dst_memory`. +typedef PJRT_Error* PJRT_Buffer_CopyToMemory( + PJRT_Buffer_CopyToMemory_Args* args); + +struct PJRT_Buffer_IsOnCpu_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + bool is_on_cpu; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_IsOnCpu_Args, is_on_cpu); + +// Whether this buffer is on CPU and thus allows for certain optimizations. +typedef PJRT_Error* PJRT_Buffer_IsOnCpu(PJRT_Buffer_IsOnCpu_Args* args); + +struct PJRT_Buffer_Device_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + PJRT_Device* device; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_Device_Args, device); + +// Returns this buffer's storage device. +typedef PJRT_Error* PJRT_Buffer_Device(PJRT_Buffer_Device_Args* args); + +struct PJRT_Buffer_Memory_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + PJRT_Memory* memory; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_Memory_Args, memory); + +// Returns this buffer's storage memory. +typedef PJRT_Error* PJRT_Buffer_Memory(PJRT_Buffer_Memory_Args* args); + +struct PJRT_Buffer_ReadyEvent_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + // The caller is responsible for calling PJRT_Event_Destroy on `event`. + PJRT_Event* event; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_ReadyEvent_Args, event); + +// Returns an event that is triggered when either of the following happens: +// * the data in the PJRT_Buffer becomes ready, or +// * an error has occurred. +// +// TODO(b/241967811): change these weird semantics +// If the buffer has been deleted or donated, the returned event will +// immediately indicate an error. However, if PJRT_Buffer_ReadyEvent() is +// called on the buffer before PJRT_Buffer_Delete() is, the returned event will +// not transition to an error state after PJRT_Buffer_Delete() is called. +typedef PJRT_Error* PJRT_Buffer_ReadyEvent(PJRT_Buffer_ReadyEvent_Args* args); + +struct PJRT_Buffer_UnsafePointer_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + uintptr_t buffer_pointer; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_UnsafePointer_Args, buffer_pointer); + +// Returns platform-dependent address for the given buffer that is often but +// not guaranteed to be the physical/device address. +typedef PJRT_Error* PJRT_Buffer_UnsafePointer( + PJRT_Buffer_UnsafePointer_Args* args); + +struct PJRT_Buffer_IncreaseExternalReferenceCount_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_IncreaseExternalReferenceCount_Args, + buffer); + +// Increments the reference count for the buffer. The reference count indicates +// the raw buffer data is being shared with another framework (e.g. NumPy, +// dlpack) and should not be deleted or moved by the PJRT implementation (e.g. +// for memory compaction). TODO(b/295230663): document more API contract +// details, e.g. does this block, can the buffer be modified in-place. +typedef PJRT_Error* PJRT_Buffer_IncreaseExternalReferenceCount( + PJRT_Buffer_IncreaseExternalReferenceCount_Args* args); + +struct PJRT_Buffer_DecreaseExternalReferenceCount_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_DecreaseExternalReferenceCount_Args, + buffer); + +// Decrements the reference count for the buffer. Returns an error if the +// reference count is zero (i.e. PJRT_Buffer_IncreaseExternalReferenceCount is +// not called beforehand). +typedef PJRT_Error* PJRT_Buffer_DecreaseExternalReferenceCount( + PJRT_Buffer_DecreaseExternalReferenceCount_Args* args); + +struct PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args { + size_t struct_size; + void* priv; + PJRT_Buffer* buffer; + void* device_memory_ptr; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args, + device_memory_ptr); + +// Returns the opaque device memory data pointer of the buffer. The returned +// data pointer may become invalid at any point unless the external reference +// count is greater than 0 via PJRT_Buffer_IncreaseExternalReferenceCount. +typedef PJRT_Error* PJRT_Buffer_OpaqueDeviceMemoryDataPointer( + PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args* args); + +// ---------------------------- CopyToDeviceStream ----------------------------- + +struct PJRT_CopyToDeviceStream_Destroy_Args { + size_t struct_size; + void* priv; + PJRT_CopyToDeviceStream* stream; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_CopyToDeviceStream_Destroy_Args, stream); + +// Frees `stream`. `stream` can be nullptr. +typedef PJRT_Error* PJRT_CopyToDeviceStream_Destroy( + PJRT_CopyToDeviceStream_Destroy_Args* args); + +struct PJRT_CopyToDeviceStream_AddChunk_Args { + size_t struct_size; + void* priv; + PJRT_CopyToDeviceStream* stream; + // Takes ownership of `chunk` (i.e. implementation will call chunk.deleter). + PJRT_Chunk* chunk; + PJRT_Event* transfer_complete; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_CopyToDeviceStream_AddChunk_Args, + transfer_complete); + +// Emplaces a new chunk of data to copy to the device. The transfer is started +// immediately, and the returned event is triggered when the transfer completes +// or fails. +// +// The returned event will indicate an error if the chunk's size causes the +// amount of transferred data to exceed the total bytes, if the stream is +// already complete, or if the chunk is not a multiple of the granule size. +typedef PJRT_Error* PJRT_CopyToDeviceStream_AddChunk( + PJRT_CopyToDeviceStream_AddChunk_Args* args); + +struct PJRT_CopyToDeviceStream_TotalBytes_Args { + size_t struct_size; + void* priv; + PJRT_CopyToDeviceStream* stream; + int64_t total_bytes; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_CopyToDeviceStream_TotalBytes_Args, total_bytes); + +// Returns the total amount of data the stream expects to be transferred. +typedef PJRT_Error* PJRT_CopyToDeviceStream_TotalBytes( + PJRT_CopyToDeviceStream_TotalBytes_Args* args); + +struct PJRT_CopyToDeviceStream_GranuleSize_Args { + size_t struct_size; + void* priv; + PJRT_CopyToDeviceStream* stream; + int64_t granule_size_in_bytes; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_CopyToDeviceStream_GranuleSize_Args, + granule_size_in_bytes); + +// Returns the granule size in bytes. The size of the chunk added to this stream +// must be a multiple of this number. +typedef PJRT_Error* PJRT_CopyToDeviceStream_GranuleSize( + PJRT_CopyToDeviceStream_GranuleSize_Args* args); + +struct PJRT_CopyToDeviceStream_CurrentBytes_Args { + size_t struct_size; + void* priv; + PJRT_CopyToDeviceStream* stream; + int64_t current_bytes; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_CopyToDeviceStream_CurrentBytes_Args, + current_bytes); + +// Returns the amount of data the stream currently has either transferred or has +// buffered to transfer. +typedef PJRT_Error* PJRT_CopyToDeviceStream_CurrentBytes( + PJRT_CopyToDeviceStream_CurrentBytes_Args* args); + +// ------------------------------ Device Topology ------------------------------ + +struct PJRT_TopologyDescription_Create_Args { + size_t struct_size; + void* priv; + const char* topology_name; + size_t topology_name_size; + // Extra platform-specific options to create a client. + const PJRT_NamedValue* create_options; + size_t num_options; + PJRT_TopologyDescription* topology; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_Create_Args, topology); + +// Creates and initializes a new PJRT_TopologyDescription and returns in +// `topology`. +typedef PJRT_Error* PJRT_TopologyDescription_Create( + PJRT_TopologyDescription_Create_Args* args); + +struct PJRT_TopologyDescription_Destroy_Args { + size_t struct_size; + void* priv; + PJRT_TopologyDescription* topology; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_Destroy_Args, topology); + +// Frees `topology`. `topology` can be nullptr. +typedef PJRT_Error* PJRT_TopologyDescription_Destroy( + PJRT_TopologyDescription_Destroy_Args* args); + +struct PJRT_TopologyDescription_PlatformVersion_Args { + size_t struct_size; + void* priv; + PJRT_TopologyDescription* topology; + // `platform_version` has the same lifetime as `topology`. It's owned by + // `topology`. + const char* platform_version; // out + size_t platform_version_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_PlatformVersion_Args, + platform_version_size); + +// Returns a string containing human-readable, platform-specific version info +// (e.g. the CUDA version on GPU or libtpu version on Cloud TPU). +typedef PJRT_Error* PJRT_TopologyDescription_PlatformVersion( + PJRT_TopologyDescription_PlatformVersion_Args* args); + +struct PJRT_TopologyDescription_PlatformName_Args { + size_t struct_size; + void* priv; + PJRT_TopologyDescription* topology; + // `platform_name` has the same lifetime as `topology`. It is owned by + // `topology`. + const char* platform_name; // out + size_t platform_name_size; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_PlatformName_Args, + platform_name_size); + +// Returns a string that identifies the platform (e.g. "cpu", "gpu", "tpu"). +typedef PJRT_Error* PJRT_TopologyDescription_PlatformName( + PJRT_TopologyDescription_PlatformName_Args* args); + +struct PJRT_TopologyDescription_GetDeviceDescriptions_Args { + size_t struct_size; + void* priv; + PJRT_TopologyDescription* topology; + // Has the same lifetime as topology. + PJRT_DeviceDescription* const* descriptions; // out + size_t num_descriptions; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_GetDeviceDescriptions_Args, + num_descriptions); + +// Returns descriptions for all devices in this topology. The device +// descriptions can be returned in any order, but will be in the same order +// across calls within a process. +typedef PJRT_Error* PJRT_TopologyDescription_GetDeviceDescriptions( + PJRT_TopologyDescription_GetDeviceDescriptions_Args* args); + +typedef struct PJRT_SerializedTopology PJRT_SerializedTopology; + +struct PJRT_TopologyDescription_Serialize_Args { + size_t struct_size; + void* priv; + PJRT_TopologyDescription* topology; + + // Lives only as long as serialized_topology. + const char* serialized_bytes; // out + size_t serialized_bytes_size; // out + + PJRT_SerializedTopology* serialized_topology; // out + // Must be called exactly once to free the backing memory for + // serialized_bytes. + void (*serialized_topology_deleter)( + PJRT_SerializedTopology* serialized_topology); // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_Serialize_Args, + serialized_topology_deleter); + +// Serializes the TopologyDescription to a string for use in cache keys. +typedef PJRT_Error* PJRT_TopologyDescription_Serialize( + PJRT_TopologyDescription_Serialize_Args* args); + +struct PJRT_TopologyDescription_Attributes_Args { + size_t struct_size; + void* priv; + PJRT_TopologyDescription* topology; + + // Only lives as long as topology. + const PJRT_NamedValue* attributes; // out + size_t num_attributes; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_Attributes_Args, + num_attributes); + +// Returns platform-specific topology attributes. +typedef PJRT_Error* PJRT_TopologyDescription_Attributes( + PJRT_TopologyDescription_Attributes_Args* args); + +struct PJRT_Compile_Args { + size_t struct_size; + void* priv; + const PJRT_TopologyDescription* topology; + // Only needs to stay alive for the duration of the Compile call. + // `program->format` and `program->format_size` are owned by the caller. + const PJRT_Program* program; + // TODO(b/240560013): consider putting some of option fields in priv. + // Serialized CompileOptionsProto + // (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/compile_options.proto) + const char* compile_options; + size_t compile_options_size; + // Optionally provided for performance-guided optimizations. + PJRT_Client* client; + PJRT_Executable* executable; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Compile_Args, executable); + +// Compiles a program in specified format (such as MLIR or HLO) with given +// `options`. The returned executable must be loaded by a compatible +// PJRT_Client before execution. +typedef PJRT_Error* PJRT_Compile(PJRT_Compile_Args* args); + +// -------------------------------- Extension ---------------------------------- + +typedef enum { + PJRT_Structure_Type_Gpu_Custom_Call = 0, + PJRT_Structure_Type_Profiler, +} PJRT_Structure_Type; + +// PJRT_Structure_Base contains a type and a pointer to next +// PJRT_Structure_Base. The framework can go through this chain to find +// structure and identify it with the type. +typedef struct PJRT_Structure_Base { + PJRT_Structure_Type type; + const struct PJRT_Structure_Base* next; +} PJRT_Structure_Base; + +// -------------------------------- API access --------------------------------- + +#define _PJRT_API_STRUCT_FIELD(fn_type) fn_type* fn_type + +// Please modify PJRT_Api_STRUCT_SIZE if the last field of PJRT_Api is changed. +typedef struct PJRT_Api { + size_t struct_size; + void* extension_start; + + PJRT_Api_Version pjrt_api_version; + + _PJRT_API_STRUCT_FIELD(PJRT_Error_Destroy); + _PJRT_API_STRUCT_FIELD(PJRT_Error_Message); + _PJRT_API_STRUCT_FIELD(PJRT_Error_GetCode); + + _PJRT_API_STRUCT_FIELD(PJRT_Plugin_Initialize); + _PJRT_API_STRUCT_FIELD(PJRT_Plugin_Attributes); + + _PJRT_API_STRUCT_FIELD(PJRT_Event_Destroy); + _PJRT_API_STRUCT_FIELD(PJRT_Event_IsReady); + _PJRT_API_STRUCT_FIELD(PJRT_Event_Error); + _PJRT_API_STRUCT_FIELD(PJRT_Event_Await); + _PJRT_API_STRUCT_FIELD(PJRT_Event_OnReady); + + _PJRT_API_STRUCT_FIELD(PJRT_Client_Create); + _PJRT_API_STRUCT_FIELD(PJRT_Client_Destroy); + _PJRT_API_STRUCT_FIELD(PJRT_Client_PlatformName); + _PJRT_API_STRUCT_FIELD(PJRT_Client_ProcessIndex); + _PJRT_API_STRUCT_FIELD(PJRT_Client_PlatformVersion); + _PJRT_API_STRUCT_FIELD(PJRT_Client_Devices); + _PJRT_API_STRUCT_FIELD(PJRT_Client_AddressableDevices); + _PJRT_API_STRUCT_FIELD(PJRT_Client_LookupDevice); + _PJRT_API_STRUCT_FIELD(PJRT_Client_LookupAddressableDevice); + _PJRT_API_STRUCT_FIELD(PJRT_Client_AddressableMemories); + _PJRT_API_STRUCT_FIELD(PJRT_Client_Compile); + _PJRT_API_STRUCT_FIELD(PJRT_Client_DefaultDeviceAssignment); + _PJRT_API_STRUCT_FIELD(PJRT_Client_BufferFromHostBuffer); + + _PJRT_API_STRUCT_FIELD(PJRT_DeviceDescription_Id); + _PJRT_API_STRUCT_FIELD(PJRT_DeviceDescription_ProcessIndex); + _PJRT_API_STRUCT_FIELD(PJRT_DeviceDescription_Attributes); + _PJRT_API_STRUCT_FIELD(PJRT_DeviceDescription_Kind); + _PJRT_API_STRUCT_FIELD(PJRT_DeviceDescription_DebugString); + _PJRT_API_STRUCT_FIELD(PJRT_DeviceDescription_ToString); + + _PJRT_API_STRUCT_FIELD(PJRT_Device_GetDescription); + _PJRT_API_STRUCT_FIELD(PJRT_Device_IsAddressable); + _PJRT_API_STRUCT_FIELD(PJRT_Device_LocalHardwareId); + _PJRT_API_STRUCT_FIELD(PJRT_Device_AddressableMemories); + _PJRT_API_STRUCT_FIELD(PJRT_Device_DefaultMemory); + _PJRT_API_STRUCT_FIELD(PJRT_Device_MemoryStats); + + _PJRT_API_STRUCT_FIELD(PJRT_Memory_Id); + _PJRT_API_STRUCT_FIELD(PJRT_Memory_Kind); + _PJRT_API_STRUCT_FIELD(PJRT_Memory_DebugString); + _PJRT_API_STRUCT_FIELD(PJRT_Memory_ToString); + _PJRT_API_STRUCT_FIELD(PJRT_Memory_AddressableByDevices); + + _PJRT_API_STRUCT_FIELD(PJRT_Executable_Destroy); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_Name); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_NumReplicas); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_NumPartitions); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_NumOutputs); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_SizeOfGeneratedCodeInBytes); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_GetCostAnalysis); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_OutputMemoryKinds); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_OptimizedProgram); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_Serialize); + + _PJRT_API_STRUCT_FIELD(PJRT_LoadedExecutable_Destroy); + _PJRT_API_STRUCT_FIELD(PJRT_LoadedExecutable_GetExecutable); + _PJRT_API_STRUCT_FIELD(PJRT_LoadedExecutable_AddressableDevices); + _PJRT_API_STRUCT_FIELD(PJRT_LoadedExecutable_Delete); + _PJRT_API_STRUCT_FIELD(PJRT_LoadedExecutable_IsDeleted); + _PJRT_API_STRUCT_FIELD(PJRT_LoadedExecutable_Execute); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_DeserializeAndLoad); + _PJRT_API_STRUCT_FIELD(PJRT_LoadedExecutable_Fingerprint); + + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Destroy); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_ElementType); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Dimensions); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_UnpaddedDimensions); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_DynamicDimensionIndices); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_GetMemoryLayout); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_OnDeviceSizeInBytes); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Device); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Memory); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_Delete); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_IsDeleted); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyToDevice); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_ToHostBuffer); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_IsOnCpu); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_ReadyEvent); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_UnsafePointer); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_IncreaseExternalReferenceCount); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_DecreaseExternalReferenceCount); + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_OpaqueDeviceMemoryDataPointer); + + _PJRT_API_STRUCT_FIELD(PJRT_CopyToDeviceStream_Destroy); + _PJRT_API_STRUCT_FIELD(PJRT_CopyToDeviceStream_AddChunk); + _PJRT_API_STRUCT_FIELD(PJRT_CopyToDeviceStream_TotalBytes); + _PJRT_API_STRUCT_FIELD(PJRT_CopyToDeviceStream_GranuleSize); + _PJRT_API_STRUCT_FIELD(PJRT_CopyToDeviceStream_CurrentBytes); + + _PJRT_API_STRUCT_FIELD(PJRT_TopologyDescription_Create); + _PJRT_API_STRUCT_FIELD(PJRT_TopologyDescription_Destroy); + _PJRT_API_STRUCT_FIELD(PJRT_TopologyDescription_PlatformName); + _PJRT_API_STRUCT_FIELD(PJRT_TopologyDescription_PlatformVersion); + _PJRT_API_STRUCT_FIELD(PJRT_TopologyDescription_GetDeviceDescriptions); + _PJRT_API_STRUCT_FIELD(PJRT_TopologyDescription_Serialize); + _PJRT_API_STRUCT_FIELD(PJRT_TopologyDescription_Attributes); + + _PJRT_API_STRUCT_FIELD(PJRT_Compile); + + // Always add new fields to the end of the struct. Move fields below to their + // corresponding places after each major version bump. + _PJRT_API_STRUCT_FIELD(PJRT_Executable_OutputElementTypes); + _PJRT_API_STRUCT_FIELD(PJRT_Executable_OutputDimensions); + + _PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyToMemory); + + _PJRT_API_STRUCT_FIELD(PJRT_Client_CreateViewOfDeviceBuffer); + + _PJRT_API_STRUCT_FIELD(PJRT_Executable_Fingerprint); + + _PJRT_API_STRUCT_FIELD(PJRT_Client_TopologyDescription); + + _PJRT_API_STRUCT_FIELD(PJRT_Executable_GetCompiledMemoryStats); +} PJRT_Api; + +enum { + PJRT_Api_STRUCT_SIZE = + PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_TopologyDescription) +}; + +#undef _PJRT_API_STRUCT_FIELD + +#ifdef __cplusplus +} +#endif + +#endif // XLA_PJRT_C_PJRT_C_API_H_