XLA backend for Lc0 #1949

Merged
mooskagh merged 49 commits into LeelaChessZero:master on Feb 24, 2024

Conversation

mooskagh (Member) commented Feb 11, 2024

This PR contains a few fairly independent modules which can be reviewed separately; here is a short description of each.

pjrt.cc, pjrt.h

(a C++ wrapper around the PJRT C API)

PjRt is an API that takes a graph representation in the HLO format, then compiles and executes it. In theory different backends are possible, but in practice there's only XLA (ok, there's also IREE).
HLO is a format for defining the computation graph. There are actually several slightly different dialects, most notably "StableHLO" and "XLA HLO". They also have different serialization formats, most notably text, HloModuleProto (a protobuf), and MLIR bytecode (the format from the LLVM project).
For now the code uses "XLA HLO" + "HloModuleProto", which was just the easiest to do.

To make it possible to use PjRt without pulling in a gazillion header dependencies, there is the PJRT C API. It works like this: a .so file, called a PjRt plugin, exports a function GetPjrtApi() (or sometimes several, e.g. GetGpuPjrtApi()) which returns a pointer to the PJRT_Api structure defined in pjrt_c_api.h. Everything is done by calling the function pointers in that structure.
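
As a rough illustration of that mechanism (this is a sketch, not code from this PR), loading a plugin and fetching the API table could look like the following, assuming pjrt_c_api.h from the XLA repo is on the include path and the GPU plugin .so sits in the current directory:

// Sketch only: plugin path and error handling are simplified.
#include <dlfcn.h>
#include <iostream>
#include "pjrt_c_api.h"

const PJRT_Api* LoadPjrtPlugin(const char* path) {
  void* handle = dlopen(path, RTLD_NOW | RTLD_LOCAL);
  if (!handle) return nullptr;
  // The plugin exports a parameterless factory returning the API table.
  auto get_api = reinterpret_cast<const PJRT_Api* (*)()>(
      dlsym(handle, "GetPjrtApi"));
  return get_api ? get_api() : nullptr;
}

int main() {
  const PJRT_Api* api = LoadPjrtPlugin("./pjrt_c_api_gpu_plugin.so");
  if (!api) { std::cerr << "failed to load plugin\n"; return 1; }
  std::cout << "PJRT API version: " << api->pjrt_api_version.major_version
            << "." << api->pjrt_api_version.minor_version << "\n";
  return 0;
}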

See the comment below for instructions on how to build a pjrt_c_api_gpu_plugin.so XLA PJRT plugin for GPU. For TPU it's supposed to work without any changes: just download
https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/2023-09-12/libtpu.so (or a newer file from the same location), and it should just work. No one has tested that though.

This API is not extremely convenient to use directly, so pjrt.cc and pjrt.h are simple wrappers around the functions we use for now.

hlo_builder.cc, hlo_builder.h

(a helper class to build HloModuleProto)

HLO is the input format for XLA (which we use through the PjRt API layer). hlo_builder builds an HLO module in the HloModuleProto format.

An HLO module consists of HLO computations (what other programming languages call "functions"), one of which is the "entry computation", usually named "main". In the current version of the PR, which only supports 11248.pb.gz, the module consists of just one computation (the entry one), but there are scenarios where more computations are necessary: e.g. a reduction op gets a reduction-step computation as an attribute, a conditional or loop gets computations to execute on the true/false branches or as the loop body, and it's also possible to call a computation just like you normally would call a function (although in automatically generated code there's little reason to do that).
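
For illustration (this is not output from the PR), a reduction over one dimension might look roughly like this in text form, with the reduction-step computation attached via to_apply:

%add_f32 (a: f32[], b: f32[]) -> f32[] {
    %a = f32[] parameter(0)
    %b = f32[] parameter(1)
    ROOT %sum = f32[] add(%a, %b)
}

ENTRY main {
    %x = f32[56,256] parameter(0)
    %zero = f32[] constant(0)
    ROOT %r = f32[56] reduce(%x, %zero), dimensions={1}, to_apply=%add_f32
}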

An HLO computation is a DAG of HLO instructions in static single assignment form. Every instruction has zero or more operands (HLO instructions it depends on), some "hardcoded" attributes, and exactly one output. The shapes of the operands and the output must be defined. The last instruction of the computation is called the "root instruction", and its value is the output of the computation. If we want to return multiple tensors from a computation, the root instruction is usually a tuple.

Here is an example of an HLO computation in text form, just to get a feel for it (batch size 56):

ENTRY main {
    %i0 = f32[56,112,8,8] parameter(0), metadata={op_type="input", op_name="/input/planes"}
    %i1 = f32[256,112,3,3] parameter(1), metadata={op_type="initializer", op_name="/inputconv/w/kernel"}
    %i2 = f32[256] constant(-1.40709,-0.332952,...), metadata={op_type="initializer", op_name="/inputconv/w/bias"}
    %i3 = f32[56,256,8,8] convolution(%i0, %i1), window={size=3x3 pads=1_1x1_1}, dim_labels=bf01_oi01->bf01, metadata={op_type="Conv", op_name="/inputconv"}
    %i4 = f32[56,256,8,8] broadcast(%i2), dimensions={1}, metadata={op_type="Conv", op_name="/inputconv"}
    %i5 = f32[56,256,8,8] add(%i3, %i4), metadata={op_type="Conv", op_name="/inputconv"}
    %i6 = f32[] constant(0), metadata={op_type="Relu", op_name="/inputconv/relu"}
    ...
    %i384 = f32[56,1] add(%i381, %i383), metadata={op_type="Add", op_name="/value/dense2/add"}
    %i385 = f32[56,1] tanh(%i384), metadata={op_type="Tanh", op_name="/output/value"}
    ROOT %i386 = (f32[56,1858], f32[56,1]) tuple(%i362, %i385)
}

onnx2hlo.h, onnx2hlo.cc

(a converter from an ONNX graph to an HLO module)

It goes node by node through the ONNX graph, and uses hlo_builder to emit the HLO graph.
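
As a toy illustration of that node-by-node walk (all names here are hypothetical stand-ins, not the actual onnx2hlo.cc or hlo_builder.h interfaces):

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for an ONNX node; the real code reads onnx::NodeProto.
struct OnnxNode {
  std::string op_type;              // e.g. "Conv", "Relu", "Tanh"
  std::vector<std::string> inputs;  // names of input tensors
  std::string output;               // name of the produced tensor
};

int main() {
  // Maps an ONNX tensor name to the index of the HLO instruction producing it.
  std::map<std::string, int> tensor_to_instr{{"/input/planes", 0},
                                             {"/inputconv/w/kernel", 1}};
  int num_instructions = 2;

  std::vector<OnnxNode> graph{
      {"Conv", {"/input/planes", "/inputconv/w/kernel"}, "/inputconv/out"},
      {"Relu", {"/inputconv/out"}, "/inputconv/relu/out"}};

  for (const OnnxNode& node : graph) {
    for (const std::string& in : node.inputs)
      if (!tensor_to_instr.count(in))
        throw std::runtime_error("unknown input tensor: " + in);
    // The real converter would call hlo_builder here to emit one or more HLO
    // instructions per ONNX node (e.g. convolution + broadcast + add for
    // "Conv", as in the example above).
    tensor_to_instr[node.output] = num_instructions++;
  }
  std::cout << "emitted " << num_instructions << " HLO instructions\n";
  return 0;
}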

xla_runner.h, xla_runner.cc

This is a module that compiles the HLO for different batch sizes, keeps the compiled executables loaded, owns the shared input buffers (through which large constants are passed), rounds up the sizes of input tensors to match the minibatch size, and decides which executable to run (based on the batch size).

It does most of its work through pjrt.h and doesn't depend on the converters etc. (it has no idea what ONNX is; it operates on HLO, although the HLO is also just passed through to PjRt and treated as a black box).
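
A small self-contained sketch of the batch-size rounding idea (the set of compiled sizes and all names are made up for illustration; the actual logic in xla_runner.cc may differ):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Batch sizes for which an executable has been compiled, sorted ascending
// (illustrative values).
const std::vector<size_t> kCompiledBatchSizes = {56, 112, 224, 448};

// Returns the smallest compiled batch size that fits the request; the input
// tensors are then zero-padded from `requested` rows up to that size.
// Requests larger than the largest compiled size would have to be split into
// several executions (not shown here).
size_t RoundUpBatch(size_t requested) {
  auto it = std::lower_bound(kCompiledBatchSizes.begin(),
                             kCompiledBatchSizes.end(), requested);
  return it != kCompiledBatchSizes.end() ? *it : kCompiledBatchSizes.back();
}

int main() {
  for (size_t n : {1u, 56u, 57u, 300u})
    std::cout << "batch " << n << " -> runs executable for "
              << RoundUpBatch(n) << "\n";
  return 0;
}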

network_xla.cc

The Lc0 backend code. It converts the network to ONNX first, then calls onnx2hlo to convert the ONNX to HLO, then calls xla_runner to compile and execute the net.

print_hlo.cc, print_hlo.h

Pretty-prints the HLO module (the output is possibly still compatible with the "official" format, so it can be passed to XLA tools). It is currently not plugged in anywhere (though I insert it into various places of the code when debugging the conversion), but the plan is to use it in leela2onnx for (optional) output as text HLO, and maybe to dump the optimized HLO from the backend.

mooskagh (Member, Author):

This is WIP (but close to being functional).

The backend requires a PjRt C API plugin. To compile the XLA PjRt C API plugin, do the following:

  1. Clone the openxla/xla repo:
$ git clone https://github.com/openxla/xla.git
  2. Install Bazel and Bazelisk.

  3. Inside the repo, call ./configure and answer the questions. A sample session for CUDA support is below:

$ ./configure
You have bazel 6.5.0 installed.
Please specify the location of python. [Default is /usr/bin/python3]:


Found possible Python library paths:
  /usr/lib/python3.11/site-packages
Please input the desired Python library path to use.  Default is [/usr/lib/python3.11/site-packages]

Do you wish to build XLA with ROCm support? [y/N]:
No ROCm support will be enabled for XLA.

Do you wish to build XLA with CUDA support? [y/N]: y
CUDA support will be enabled for XLA.

Could not find any cuda.h matching version '' in any subdirectory:
        ''
        'include'
        'include/cuda'
        'include/*-linux-gnu'
        'extras/CUPTI/include'
        'include/cuda/CUPTI'
        'local/cuda/extras/CUPTI/include'
        'targets/x86_64-linux/include'
of:
        '/opt/cuda/extras/CUPTI/lib64'
        '/opt/cuda/lib64'
        '/opt/cuda/nvvm/lib64'
        '/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin'
        '/opt/intel/oneapi/compiler/latest/linux/lib'
        '/opt/intel/oneapi/compiler/latest/linux/lib/x64'
        '/opt/intel/oneapi/lib/intel64'
        '/opt/intel/oneapi/mkl/latest/lib/intel64'
        '/opt/intel/oneapi/tbb/latest/lib/intel64/gcc4.8'
        '/usr'
        '/usr/lib'
        '/usr/lib/R/lib'
        '/usr/lib/libfakeroot'
        '/usr/lib/libfakeroot/fakechroot'
        '/usr/lib32'

Asking for detailed CUDA configuration...

Please specify the CUDA SDK version you want to use. [Leave empty to default to CUDA 11]: 12


Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 2]: 8


Please specify the locally installed NCCL version you want to use. [Leave empty to use http://github.com/nvidia/nccl]:


Please specify the comma-separated list of base paths to look for CUDA libraries and headers. [Leave empty to use the default]: /opt/cuda,/usr


Found CUDA 12.3 in:
    /opt/cuda/targets/x86_64-linux/lib
    /opt/cuda/targets/x86_64-linux/include
Found cuDNN 8 in:
    /usr/lib
    /usr/include


Please specify a list of comma-separated CUDA compute capabilities you want to build with.
You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus. Each capability can be specified as "x.y" or "compute_xy" to include both virtual and binary GPU code, or as "sm_xy" to only include the binary code.
Please note that each additional compute capability significantly increases your build time and binary size, and that XLA only supports compute capabilities >= 5.2 [Default is: 5.2]:


Do you want to use clang as CUDA compiler? [y/N]:
nvcc will be used as CUDA compiler.

Please specify which gcc should be used by nvcc as the host compiler. [Default is /usr/bin/gcc-12]:


Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -Wno-sign-compare]:


Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
	--config=mkl         	# Build with MKL support.
	--config=mkl_aarch64 	# Build with oneDNN and Compute Library for the Arm Architecture (ACL).
	--config=monolithic  	# Config for mostly static monolithic build.
	--config=numa        	# Build with NUMA support.
	--config=dynamic_kernels	# (Experimental) Build kernels into separate shared objects.
	--config=v1          	# Build with TensorFlow 1 API instead of TF 2 API.
Preconfigured Bazel build configs to DISABLE default on features:
	--config=nogcp       	# Disable GCP support.
	--config=nonccl      	# Disable NVIDIA NCCL support.
Configuration finished
  4. Build the plugin (takes some hours):
$ bazel build -c opt //xla/pjrt/c:pjrt_c_api_gpu_plugin.so
  5. After that, the resulting plugin will be located at <xla-repo>/bazel-bin/xla/pjrt/c/pjrt_c_api_gpu_plugin.so.

The elixir-nx repository provides pre-compiled XLA binaries, and it's possible that the .so files they ship also include the PJRT API (I didn't check). It's also possible that it's there but requires a different function call (GetGpuPjrtApi() rather than GetPjrtApi()), which would be a simple change in the Lc0 code.

@mooskagh mooskagh marked this pull request as ready for review February 23, 2024 22:25
@mooskagh mooskagh requested a review from borg323 February 23, 2024 22:25
borg323 (Member) left a comment

Can't test but it looks good and is well separated. Just one build system suggestion.

meson.build: one review comment (outdated, resolved)
@mooskagh mooskagh merged commit 04f73fc into LeelaChessZero:master Feb 24, 2024
3 checks passed
PikaCat-OuO pushed a commit to official-pikafish/px0 that referenced this pull request Mar 20, 2024