[CPU] Support Intel CPU inference (microsoft#3041)
* add fallback path for kernels used in megatron

* temporary numactl workaround for SPR 56-core

* adapt core allocation according to number of ranks

* add switch to turn on numactl

* detect number of cores on the system

* allow selecting a subset of the cores on the system to bind

* remove unneeded changes

* add ccl backend

* change nccl to ccl

* remove unused code

* add comm/ccl to ops

* initial ccl comm support

* first broadcast case passed

* add CCL_Backend to DeepSpeed

* support comm timer for CPU

* support barrier for comm backend

* support specifying the master address from the deepspeed command line

* support pytorch 2.0

* remove 'block' from api

* Tweak for debug

Signed-off-by: Cao, Zhong Z <[email protected]>

* Remove unnecessary directory

Signed-off-by: Cao, Zhong Z <[email protected]>

* Add bf16 kernel support for inference

* Add temporary torch implementation for cpu inference

* Add softmax ops cpu fallback for inference

* bind cores to numa domain as well

* merge latest change in gma/numactl

* initial bf16 kernel support with fallback path

* initial fallback path for bloom kernel injection

* fix softmax attn mask

* check KMP_AFFINITY to avoid conflict with numactl

* New CCLBackend which utilizes TorchBackend for initialization

* roll back last change because it produced incorrect results

* fix issue where TP could not work with the bloom injection policy:

injection_policy={BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")}

* Use TorchBackend to initialize CCLBackend, make behavior consistent

* remove comm under deepspeed/ops

* add license header

* code clean up

* fix format issue

* remove magic number in main address

* add caching support but do not turn it on by default

* change name of inference_cuda_module to inference_module

* Check for is_synchronized_device in accelerator before getting Event

* fix typo

* Fix fallback path of softmax kernel on CUDA device for BF16 data type: because CUDA tril does not support BF16, enforce the FP32 data type

* add cpu backend files

* change CPU_Accelerator op_builder_dir

* remove cpu_kernel_path

* using CPU_Accelerator on non-cuda device

* fix deepspeed.op_builder => deepspeed.ops.op_builder

* add alias for num_gpus: num_accelerators

* allow loading cpu_builder in build stage

* Assume cuda available if torch not installed

* add oneccl_binding_pt to requirements

* move oneccl-binding-pt to separate requirements-cpu.txt

* add missing file

* use dependency_links in setuptools.setup() call for additional dependency links

* install oneccl_bind_pt in workflows

* change oneccl_bind_pt's version from 1.13 to 2.0

* use intel_extension_for_pytorch as indicator that CPU_Accelerator should be used

* Add indicator for Accelerator used

* change foo.c to foo.cpp

* exclude 'cpu' directory in CUDA op builder reflection

* add a cpu-inference workflow

* run cpu-inference workflow on self-hosted instance

* change cpu runs-on node to v100 node

* print out python version in workflow

* add verbose flag to pip command to understand oneccl_bind_pt install issue

* update cpu-inference workflow

* add a stage to detect instance instruction sets

* add back bf16 support for CPU inference

* enable autoTP for bloom

Signed-off-by: Wang, Yi A <[email protected]>

* update workflow to detect cpu instruction sets

* temporary WA for Intel Extension for PyTorch AVX2 instruction set detection

* change cpu-inference workflow machine to ubuntu-20.04

* add sharded checkpoint loading for AutoTP path to reduce the peak memory in initialization stage

Signed-off-by: Wang, Yi A <[email protected]>

* enable policy for llama

* use a special build ipex to test avx2 detection fix

* fix format

* fix test fail issue

Signed-off-by: Wang, Yi A <[email protected]>

* fix gptj sharded checkpoint loading problem

Signed-off-by: Wang, Yi A <[email protected]>

* return a not-implemented builder from get_op_builder in cpu_backend

* support cpu device in tests

* use cpuinfo to extract number of CPUs

* use ~/tmp as transformer cache rather than /blob/

* Add support for mpich launcher with prefer_deepspeed_comm

* add missing modification in accelerator

* enable IMPI launcher

* remove unused file and fix formatting

* clean up ccl.cpp

* Less confusing error message when certain op builders are not implemented

* Fix license header

* Add license header

* add license headers

* add license header

* fix cuda specific code in test

* update CPU workflow

* use numactl to bind to core

* allow bind_cores_to_rank in multi-node impi runner

* fix format error

* Remove InferenceBuilder

* fix format error in numa.py

* check whether op is in installed ops in ds_report.py

* allow overriding the accelerator with DS_ACCELERATOR='cuda', 'cpu' or 'xpu' (see the usage sketch after this list)

* lazy init class_dict in CUDA_Accelerator to avoid cyclic initialization of CUDA_Accelerator

* put short path in the beginning in real_accelerator.py

* device_count returns number of NUMA nodes

* fix typo

* install numactl in cpu workflow

* Follow comments

* Better implementation of device_count() and current_device()

* remove dependency_link for Intel Extension for DeepSpeed

* check is_synchronized_device in timer only once

* remove env mapping WA in cpu_accelerator

* fix duplicate definition

* fix format error

* refine ccl backend selection

* move comments to the right place

* remove prefer_deepspeed_comm, use CCLBackend by default

* refactor fallback path

* Fix execution failure in kernel injection path

* do not refactor kernel injection fallback path in residual_add because it contains a function call with side effects

* guard residual_add fallback path with environ DS_KI_FALLBACK=True

* fix format error

* add test for allreduce on CPU workflow

* fix format error

* Fall back to TorchBackend if CCLBackend kernels are not implemented

* Update Intel Extension for PyTorch installation link

* Don't specify version number of Intel Extension for PyTorch

* install oneCCL for CCLBackend

* fix link path for CPU comm kernels

* fix source oneCCL environment

* source oneCCL env before run UT

* Give more specific instructions when CCL_ROOT is not defined
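
A minimal end-to-end sketch of the CPU inference path these commits add (the launch command, model name and WORLD_SIZE handling below are illustrative assumptions, not part of this diff):

# Illustrative launch: DS_ACCELERATOR=cpu deepspeed --bind_cores_to_rank run_bloom.py
# DS_ACCELERATOR=cpu forces the CPU accelerator; --bind_cores_to_rank uses numactl to
# bind each rank to its NUMA domain, as described in the commits above.
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "bigscience/bloom-7b1"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# AutoTP path (replace_with_kernel_inject=False); on CPU, deepspeed.comm uses CCLBackend
model = deepspeed.init_inference(model,
                                 mp_size=int(os.getenv("WORLD_SIZE", "1")),
                                 dtype=torch.bfloat16,
                                 replace_with_kernel_inject=False)

inputs = tokenizer("DeepSpeed inference on Intel CPU", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)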

---------

Signed-off-by: Cao, Zhong Z <[email protected]>
Signed-off-by: Wang, Yi A <[email protected]>
Co-authored-by: sdp <[email protected]>
Co-authored-by: Cao, Zhong Z <[email protected]>
Co-authored-by: Zhenhuan Chen <[email protected]>
Co-authored-by: baodii <[email protected]>
Co-authored-by: Wang, Yi A <[email protected]>
Co-authored-by: jianan-gu <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
9 people authored May 16, 2023
1 parent 5147b90 commit 1f72082
Showing 43 changed files with 1,414 additions and 329 deletions.
83 changes: 83 additions & 0 deletions .github/workflows/cpu-inference.yml
@@ -0,0 +1,83 @@
name: cpu-inference

on:
  push:
    branches:
      - 'staging**'
    paths-ignore:
      - 'docs/**'
  pull_request:
    paths-ignore:
      - 'docs/**'

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  unit-tests:
    runs-on: ubuntu-20.04

    steps:
      - uses: actions/checkout@v2

      - id: setup-venv
        uses: ./.github/workflows/setup-venv

      - name: Detect instruction sets on instance
        run: |
          lscpu
          pip install cmake
          git clone https://github.com/intel/intel-extension-for-pytorch
          cd intel-extension-for-pytorch/tests/cpu/isa
          cmake .
          make
          ./cpu_features
      - name: Install numactl
        run: |
          sudo apt-get install -y numactl
      - name: Install oneCCL Bindings for PyTorch
        run: |
          python -m pip install intel_extension_for_pytorch
          python -m pip install oneccl_bind_pt==2.0 -f https://developer.intel.com/ipex-whl-stable-cpu
      - name: Install oneCCL
        run: |
          git clone https://github.com/oneapi-src/oneCCL
          cd oneCCL
          mkdir build
          cd build
          cmake ..
          make
          make install
          #source ./_install/env/setvars.sh
          # test whether oneCCL is correctly installed
          #mpirun -n 2 ./examples/benchmark/benchmark
      - name: Install transformers
        run: |
          git clone https://github.com/huggingface/transformers
          cd transformers
          git rev-parse --short HEAD
          pip install .
      - name: Install deepspeed
        run: |
          # check why the host does not have AVX2 support
          pip install .[dev,1bit,autotuning,inf]
          ds_report
      - name: Python environment
        run: |
          pip list
      - name: Unit tests
        run: |
          source oneCCL/build/_install/env/setvars.sh
          unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
          if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
          cd tests
          TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference' unit/inference/test_inference_config.py
          TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -k TestDistAllReduce unit/comm/test_dist.py
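
The workflow installs intel_extension_for_pytorch because its presence is what DeepSpeed uses as the indicator to select the CPU accelerator (see the commit list above). A simplified sketch of that selection, assuming the DS_ACCELERATOR override described earlier; the real logic in accelerator/real_accelerator.py handles more cases:

import os

def detect_accelerator_name():
    # An explicit override always wins: DS_ACCELERATOR='cuda', 'cpu' or 'xpu'
    override = os.environ.get("DS_ACCELERATOR")
    if override is not None:
        return override
    try:
        import intel_extension_for_pytorch  # noqa: F401  # indicator for the CPU path
        return "cpu"
    except ImportError:
        return "cuda"  # otherwise assume CUDA, matching the default path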
4 changes: 4 additions & 0 deletions accelerator/abstract_accelerator.py
@@ -13,6 +13,10 @@ def __init__(self):
        self._name = None
        self._communication_backend_name = None

    @abc.abstractmethod
    def is_synchronized_device(self):
        ...

    # Device APIs
    @abc.abstractmethod
    def device_name(self, device_index):
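
The new abstract is_synchronized_device() hook lets device-agnostic code, such as the communication timers mentioned in the commit list, skip events and stream synchronization on devices that execute synchronously. A hedged sketch of the intended caller pattern (the helper below is illustrative, not code from this diff):

import time
from deepspeed.accelerator import get_accelerator

def start_timing():
    accel = get_accelerator()
    if accel.is_synchronized_device():
        # CPU runs ops synchronously, so wall-clock time is sufficient
        return time.time()
    # On asynchronous devices (e.g. CUDA), record an event instead
    event = accel.Event(enable_timing=True)
    event.record()
    return event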
260 changes: 260 additions & 0 deletions accelerator/cpu_accelerator.py
@@ -0,0 +1,260 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
import oneccl_bindings_for_pytorch # noqa: F401
import psutil
import os


# accelerator for Intel CPU
class CPU_Accelerator(DeepSpeedAccelerator):

    def __init__(self):
        self._name = 'cpu'
        self._communication_backend_name = 'ccl'
        self.max_mem = psutil.Process().memory_info().rss

    def is_synchronized_device(self):
        return True

    # Device APIs
    def device_name(self, device_index=None):
        return 'cpu'

    def device(self, device_index=None):
        return None

    def set_device(self, device_index):
        return

    def current_device(self):
        return os.environ.get('LOCAL_RANK', 0)

    def current_device_name(self):
        return 'cpu'

    def device_count(self):
        device_count = int(os.environ.get('LOCAL_SIZE', 0))
        if device_count > 0:
            return device_count
        else:
            from deepspeed.utils.numa import get_numa_cores
            # Count NUMA nodes as the number of cpu accelerators. On a machine with HBM
            # in flat mode, HBM sits in a separate NUMA node with no cores on it.
            # Ignore these NUMA nodes with no cores.
            numa_core_lists = get_numa_cores()
            numa_count = 0
            for core_list in numa_core_lists:
                if len(core_list) > 0:
                    numa_count += 1
            return numa_count

    def synchronize(self, device_index=None):
        return

    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        if device_index is None:
            return torch.set_rng_state(new_state)
        return torch.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        return torch.get_rng_state()

    def manual_seed(self, seed):
        return torch.manual_seed(seed)

    def manual_seed_all(self, seed):
        return torch.manual_seed(seed)

    def initial_seed(self, seed):
        return torch.initial_seed(seed)

    def default_generator(self, device_index):
        return torch.default_generator

    # Streams/Events
    @property
    def Stream(self):
        return None

    def stream(self, stream):
        from deepspeed.runtime.utils import noop_decorator
        return noop_decorator

    def current_stream(self, device_index=None):
        return None

    def default_stream(self, device_index=None):
        return None

    @property
    def Event(self):
        return None

    # Memory management
    def empty_cache(self):
        return

    def get_rss(self):
        mem = psutil.Process().memory_info().rss
        if mem > self.max_mem:
            self.max_mem = mem
        return mem

    def reset_rss(self):
        mem = psutil.Process().memory_info().rss
        self.max_mem = mem
        return mem

    def memory_allocated(self, device_index=None):
        return self.get_rss()

    def max_memory_allocated(self, device_index=None):
        self.get_rss()
        return self.max_mem

    def reset_max_memory_allocated(self, device_index=None):
        self.reset_rss()
        return

    def memory_cached(self, device_index=None):
        return self.get_rss()

    def max_memory_cached(self, device_index=None):
        self.get_rss()
        return self.max_mem

    def reset_max_memory_cached(self, device_index=None):
        self.reset_rss()
        return

    def memory_stats(self, device_index=None):
        return self.get_rss()

    def reset_peak_memory_stats(self, device_index=None):
        self.reset_rss()
        return

    def memory_reserved(self, device_index=None):
        return self.get_rss()

    def max_memory_reserved(self, device_index=None):
        self.get_rss()
        return self.max_mem

    def total_memory(self, device_index=None):
        return psutil.virtual_memory().total

    # Misc
    def amp(self):
        return torch.cpu.amp

    def is_available(self):
        return True

    def range_push(self, msg):
        # TODO itt is currently not supported yet
        # return torch.profiler.itt.range_push(msg)
        return

    def range_pop(self):
        # TODO itt is currently not supported yet
        # return torch.profiler.itt.range_pop()
        return

    def lazy_call(self, callback):
        return callback()

    def communication_backend_name(self):
        return self._communication_backend_name

    # Data types
    def is_bf16_supported(self):
        return True

    def is_fp16_supported(self):
        return True

    # Tensor operations

    @property
    def BFloat16Tensor(self):
        return torch.BFloat16Tensor

    @property
    def ByteTensor(self):
        return torch.ByteTensor

    @property
    def DoubleTensor(self):
        return torch.DoubleTensor

    @property
    def FloatTensor(self):
        return torch.FloatTensor

    @property
    def HalfTensor(self):
        return torch.HalfTensor

    @property
    def IntTensor(self):
        return torch.IntTensor

    @property
    def LongTensor(self):
        return torch.LongTensor

    def pin_memory(self, tensor):
        return tensor

    def op_builder_dir(self):
        try:
            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
            # if successful this also means we're doing a local install and not JIT compile path
            from op_builder import __deepspeed__  # noqa: F401
            return "op_builder.cpu"
        except ImportError:
            return "deepspeed.ops.op_builder.cpu"

    def on_accelerator(self, tensor):
        device_str = str(tensor.device)
        if device_str.startswith('cpu'):
            return True
        else:
            return False

    # create an instance of op builder and return it, name specified by class_name
    def create_op_builder(self, op_name):
        builder_class = self.get_op_builder(op_name)
        if builder_class is not None:
            return builder_class()
        return None

    # return an op builder class, name specified by class_name
    def get_op_builder(self, class_name):
        try:
            # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
            # if successful this also means we're doing a local install and not JIT compile path
            from op_builder import __deepspeed__  # noqa: F401
            from op_builder.cpu import CCLCommBuilder, NotImplementedBuilder
        except ImportError:
            from deepspeed.ops.op_builder.cpu import CCLCommBuilder, NotImplementedBuilder

        if class_name == "CCLCommBuilder":
            return CCLCommBuilder
        else:
            # return a NotImplementedBuilder to avoid getting NoneType[Name] errors in unit tests
            return NotImplementedBuilder

    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension
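
A short illustrative exercise of the accelerator API defined above, assuming DeepSpeed is installed with intel_extension_for_pytorch and oneccl_bindings_for_pytorch so that get_accelerator() resolves to CPU_Accelerator:

from deepspeed.accelerator import get_accelerator

accel = get_accelerator()
print(accel.device_name())                 # 'cpu'
print(accel.device_count())                # NUMA domains that actually have cores
print(accel.memory_allocated())            # current process RSS in bytes
print(accel.communication_backend_name())  # 'ccl'

# Known builders resolve normally; anything else returns a NotImplementedBuilder
# instead of None so unit tests fail with a clear message.
ccl_builder = accel.create_op_builder("CCLCommBuilder")
fallback_builder = accel.create_op_builder("InferenceBuilder")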