diff --git a/.github/workflows/cpu-inference.yml b/.github/workflows/cpu-inference.yml index 521fe2b5bea4..a2ca41f4aa3a 100644 --- a/.github/workflows/cpu-inference.yml +++ b/.github/workflows/cpu-inference.yml @@ -1,7 +1,14 @@ name: cpu-inference on: + pull_request: + paths-ignore: + - 'docs/**' + - 'blogs/**' workflow_dispatch: + merge_group: + branches: [ master ] + concurrency: group: ${{ github.workflow }}-${{ github.ref }} @@ -9,7 +16,7 @@ concurrency: jobs: unit-tests: - runs-on: ubuntu-20.04 + runs-on: [self-hosted, cpu] steps: - uses: actions/checkout@v3 @@ -17,6 +24,20 @@ jobs: - id: setup-venv uses: ./.github/workflows/setup-venv + - name: Install gcc-9 + run: | + sudo add-apt-repository -u ppa:ubuntu-toolchain-r/test + sudo apt install -y gcc-9 g++-9 + # set gcc-9 and g++9 to default + sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 99 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-9 99 + + - name: Check gcc version + run: | + # Get gcc version + gcc --version + g++ --version + - name: Detect instruction sets on instance run: | lscpu @@ -33,8 +54,16 @@ jobs: - name: Install oneCCL Bindings for PyTorch run: | + pip install torch python -m pip install intel_extension_for_pytorch - python -m pip install oneccl_bind_pt==2.0 -f https://developer.intel.com/ipex-whl-stable-cpu + # the curl line is for troubleshooting + curl -L https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ + python -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ + pip install py-cpuinfo + # check installed version + pip list |grep \\\ + pip list |grep intel-extension-for-pytorch + pip list |grep oneccl-bind-pt - name: Install oneCCL run: | @@ -62,14 +91,22 @@ jobs: pip install .[dev,1bit,autotuning,inf] ds_report - - name: Python environment + - name: Python environment check run: | pip list + source oneCCL/build/_install/env/setvars.sh + export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6 + # check whether the environment is properly setup + python -c "import torch;import intel_extension_for_pytorch as ipex;import oneccl_bindings_for_pytorch;print('done')" + python -c "import deepspeed;from deepspeed.accelerator import get_accelerator;print(get_accelerator().device_name());print(get_accelerator().is_available())" - name: Unit tests run: | + # prep oneCCL for CCLBackend comm ops building source oneCCL/build/_install/env/setvars.sh + export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6 unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch - cd tests - TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' unit/ - TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference_ops' -m 'inference' unit/ + cd tests + # LOCAL_SIZE=2 enforce CPU to report 2 devices, this helps run the test on github default runner + LOCAL_SIZE=2 COLUMNS=240 TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' unit/ + LOCAL_SIZE=2 COLUMNS=240 TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference_ops' -m 'inference' unit/ diff --git a/csrc/cpu/comm/ccl.cpp b/csrc/cpu/comm/ccl.cpp index 9a04b6f873f2..6428ab5cbfa5 100644 --- a/csrc/cpu/comm/ccl.cpp +++ b/csrc/cpu/comm/ccl.cpp @@ -537,7 +537,7 @@ static void parallel_memcpy(void* to, void* from, size_t n_bytes) } } -void inference_all_reduce(torch::Tensor& data, py::object op, std::vector group, bool async_op) +void inference_all_reduce(torch::Tensor& data, py::object op, bool async_op) { static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp"); static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value")); @@ -562,7 +562,7 @@ void inference_all_reduce(torch::Tensor& data, py::object op, std::vector g data.numel(), get_ccl_datatype(data.scalar_type()), get_ccl_reduce_op(op, data), - _get_comm_from_group(group)) + _get_comm_from_group()) .wait()); return; } diff --git a/deepspeed/comm/ccl.py b/deepspeed/comm/ccl.py index 6e915d0ca430..a199843d5086 100644 --- a/deepspeed/comm/ccl.py +++ b/deepspeed/comm/ccl.py @@ -61,7 +61,8 @@ def is_initialized(self): def run_collective(self, name, **kwargs): if name in self.available_coll: - kwargs['group'] = self.get_all_ranks_from_group(kwargs['group']) + if 'group' in kwargs: + kwargs['group'] = self.get_all_ranks_from_group(kwargs['group']) if 'dst' in kwargs: kwargs['dst'] = kwargs['group'].index(kwargs['dst']) if 'src' in kwargs: @@ -71,23 +72,38 @@ def run_collective(self, name, **kwargs): return CCLHandler(self.ccl_comm_op) else: func = "super(CCLBackend, self)." + name - return eval(func)(*(kwargs.values())) + eval(func)(*(kwargs.values())) + return CCLHandler(self.ccl_comm_op) def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False): use_caching = False if use_caching: match_id = f"{tensor.size()}-{op}" - return self.run_collective(name="all_reduce_caching", - tensor=tensor, - op=op, - match_id=match_id, - group=group, - async_op=async_op) + name = "all_reduce_caching" + if name in self.available_coll: + group = self.get_all_ranks_from_group(group) + return self.ccl_comm_op.all_reduce_caching(tensor, op, match_id, group, async_op) + else: + return self.run_collective(name=name, + tensor=tensor, + op=op, + match_id=match_id, + group=group, + async_op=async_op) else: - return self.run_collective(name="all_reduce", tensor=tensor, op=op, group=group, async_op=async_op) + name = "all_reduce" + if name in self.available_coll: + group = self.get_all_ranks_from_group(group) + return self.ccl_comm_op.all_reduce(tensor, op, group, async_op) + else: + return self.run_collective(name=name, tensor=tensor, op=op, group=group, async_op=async_op) def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False): - return self.run_collective(name="inference_all_reduce", tensor=tensor, op=op, group=group, async_op=async_op) + name = "inference_all_reduce" + if name in self.available_coll: + return self.ccl_comm_op.inference_all_reduce(tensor, op, async_op) + else: + return self.run_collective(name=name, tensor=tensor, op=op, group=None, async_op=async_op) def broadcast(self, tensor, src, group=None, async_op=False): return self.run_collective(name="broadcast", tensor=tensor, src=src, group=group, async_op=async_op) @@ -120,11 +136,11 @@ def all_to_all_single(self, output, input, output_split_sizes, input_split_sizes input_split_sizes=input_split_sizes, group=group) - def send(self, tensor, dst, group=None, async_op=False): - return self.run_collective(name="send", tensor=tensor, dst=dst, group=group, async_op=async_op) + def send(self, tensor, dst, group=None, tag=0): + return self.run_collective(name="send", tensor=tensor, dst=dst, group=group, tag=tag) - def recv(self, tensor, src, group=None, async_op=False): - return self.run_collective(name="recv", tensor=tensor, src=src, group=group, async_op=async_op) + def recv(self, tensor, src, group=None, tag=0): + return self.run_collective(name="recv", tensor=tensor, src=src, group=group, tag=tag) def gather(self, tensor, gather_list, dst, group=None, async_op=False): return self.run_collective(name="gather", tensor=tensor, gather_list=gather_list, dst=dst, group=group) @@ -170,7 +186,7 @@ def get_all_ranks_from_group(self, group): while True: results.append(super(CCLBackend, self).get_global_rank(group, rank)) rank += 1 - except ValueError: + except (ValueError, RuntimeError): pass if tuple(results) not in self.groups: self._new_group(results, group) diff --git a/docs/_tutorials/accelerator-abstraction-interface.md b/docs/_tutorials/accelerator-abstraction-interface.md index 0810c3c6b5d7..db1a6005f793 100644 --- a/docs/_tutorials/accelerator-abstraction-interface.md +++ b/docs/_tutorials/accelerator-abstraction-interface.md @@ -96,7 +96,7 @@ To run DeepSpeed model on CPU, use the following steps to prepare environment: ``` python -m pip install intel_extension_for_pytorch -python -m pip install oneccl_bind_pt==2.0 -f https://developer.intel.com/ipex-whl-stable-cpu +python -m pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable-cpu git clone https://github.com/oneapi-src/oneCCL cd oneCCL mkdir build diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 6b5588d8a1f7..767e1dba23ea 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -5,6 +5,7 @@ import os import time +import pickle import torch import pytest import itertools @@ -65,7 +66,13 @@ ] # Get a list of all models and mapping from task to supported models -_hf_models = list(HfApi().list_models()) +try: + with open("hf_models.pkl", "rb") as fp: + _hf_models = pickle.load(fp) +except FileNotFoundError: + _hf_models = list(HfApi().list_models()) + with open("hf_models.pkl", "wb") as fp: + pickle.dump(_hf_models, fp) _hf_model_names = [m.modelId for m in _hf_models] _hf_task_to_models = {task: [m.modelId for m in _hf_models if m.pipeline_tag == task] for task in _test_tasks} @@ -280,6 +287,12 @@ def test( if invalid_test_msg: pytest.skip(invalid_test_msg) + if dtype not in get_accelerator().supported_dtypes(): + pytest.skip(f"Acceleraor {get_accelerator().device_name()} does not support {dtype}.") + + if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("This op had not been implemented on this system.", allow_module_level=True) + model, task = model_w_task local_rank = int(os.getenv("LOCAL_RANK", "0")) @@ -536,9 +549,8 @@ def test( if dtype not in get_accelerator().supported_dtypes(): pytest.skip(f"Acceleraor {get_accelerator().device_name()} does not support {dtype}.") - # TODO: enable this test after torch 2.1 stable release if dtype == torch.bfloat16 and model_w_task[0] == "Salesforce/codegen-350M-mono": - pytest.skip("Codegen model(bf16) need to use torch version > 2.0.") + pytest.skip("Disable Codegen model(bf16) due to slight result difference") model, task = model_w_task local_rank = int(os.getenv("LOCAL_RANK", "0")) diff --git a/tests/unit/inference/test_inference_config.py b/tests/unit/inference/test_inference_config.py index 375563abf65b..39d62d17372c 100644 --- a/tests/unit/inference/test_inference_config.py +++ b/tests/unit/inference/test_inference_config.py @@ -15,7 +15,7 @@ class TestInferenceConfig(DistributedTest): world_size = 1 def test_overlap_kwargs(self): - config = {"replace_with_kernel_inject": True} + config = {"replace_with_kernel_inject": True, "dtype": torch.float32} kwargs = {"replace_with_kernel_inject": True} engine = deepspeed.init_inference(torch.nn.Module(), config=config, **kwargs) @@ -37,7 +37,7 @@ def test_kwargs_and_config(self): assert engine._config.dtype == kwargs["dtype"] def test_json_config(self, tmpdir): - config = {"replace_with_kernel_inject": True} + config = {"replace_with_kernel_inject": True, "dtype": "torch.float32"} config_json = create_config_from_dict(tmpdir, config) engine = deepspeed.init_inference(torch.nn.Module(), config=config_json)