XLA backend for Lc0 #1949
This is WIP (but close to being functional). The backend requires a PjRt C API plugin. To compile the XLA PjRt C API plugin, do the following:
The elixir-nx repository has a pre-compiled XLA repo, and it's possible that the
Can't test but it looks good and is well separated. Just one build system suggestion.
(cherry picked from commit 04f73fc)
This PR contains a few fairly independent modules which can be reviewed separately; here is a short description of each.
pjrt.cc, pjrt.h
(a C++ wrapper around the PJRT C API)
PjRt is an API that takes a graph representation in the HLO format, then compiles and executes it. In theory, different backends are possible, but in practice there's only XLA (ok, there's also IREE).
HLO is a format for defining a computation graph. There are actually many slightly different dialects, most notably "StableHLO" and "XLA HLO", and they have different serialization formats, most notably text, `HloModuleProto` (a protobuf), and MLIR bytecode (a binary format used in the LLVM/MLIR ecosystem).
For now the code uses "XLA HLO" + `HloModuleProto`, which was simply the easiest to do.
To make it possible to use PjRt without pulling in a gazillion header dependencies, there is the PJRT C API. It works like this: a `.so` file, called a PjRt plugin, exports a function `GetPjrtApi()` (or sometimes several, like `GetGPUPjrtApi()`) which returns a pointer to the `PJRT_Api` structure defined in `pjrt_c_api.h`. By calling the functions exposed through this structure, you can do stuff.

See the comment below for instructions on how to build a `pjrt_c_api_gpu_plugin.so` XLA PJRT plugin for GPU. For TPU it's supposed to work without any changes: just download https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/2023-09-12/libtpu.so (or a newer file from here), and it should just work. No one has ever tested that, though.
This API is not extremely convenient to use directly, so `pjrt.cc` and `pjrt.h` are simple wrappers around the functions we use for now.

hlo_builder.cc, hlo_builder.h
(a helper class to build `HloModuleProto`)

HLO is an input format for XLA (which we use through the PjRt API layer).
`hlo_builder` builds an HLO module in the `HloModuleProto` format.

An HLO module consists of HLO computations (in other programming languages, they'd be called "functions"), one of which is the "entry computation", usually named "main". In the current version of the PR, which only supports `11248.pb.gz`, the module consists of just one computation (the entry one), but there are scenarios where more computations are necessary (e.g. the `reduction` op gets a reduction-step computation as an attribute, or `condition` and `loop` get computations to execute on the `true` or `false` condition, or as a loop body; it's also possible to `call` a computation just like you normally do with a function, but in automatically generated code there's little reason for that).

An HLO computation is a DAG of HLO instructions, in static single assignment form. Every instruction has zero or more operands (HLO instruction dependencies), some "hardcoded" attributes, and exactly one output. The shapes of operands and outputs must be defined. The last instruction of the computation is called the "root instruction", and it is the output. If we want to return multiple tensors from a computation, the root instruction is usually a `tuple`.

Here is an example of an HLO computation in text form, just to get a feeling for it (batch size 56):
onnx2hlo.h, onnx2hlo.cc
(Converter from ONNX graph to HLO module)

It goes node by node through the ONNX graph, and uses `hlo_builder` to emit the HLO graph.

xla_runner.h, xla_runner.cc
This is a module that compiles HLO for different batch sizes, keeps the compiled executables loaded, owns shared input buffers (through which large constants are passed), rounds up the sizes of input tensors to match the minibatch size, and decides which executable to run (based on the batch size).
It does most of the stuff through `pjrt.h`, and it doesn't depend on converters etc. (it has no idea what ONNX is; it operates on HLO, although HLO is also just passed through to PjRt and treated as a black box).

network_xla.cc
A Lc0 backend code. It converts the network to ONNX first, then calls `onnx2hlo` to convert ONNX to HLO, then calls `xla_runner` to compile and execute the net.

print_hlo.cc, print_hlo.h
Pretty-prints the HLO module (possibly still even compatible with the "official" format, so the output can be passed to XLA tools). Currently it is not plugged in anywhere (though I insert it into various places in the code when debugging the conversion), but the plan is to use it in `leela2onnx` for (optional) output as text HLO, and maybe for dumping the optimized HLO from the backend.