[Inference] Support GPT-J-6B #1332

Closed
oborchers opened this issue Aug 28, 2021 · 37 comments
Labels
enhancement New feature or request

Comments

@oborchers

oborchers commented Aug 28, 2021

Is your feature request related to a problem? Please describe.
With the new release of transformers, the gpt-j-6b model will be available to the public: huggingface/transformers#13022

Currently,

import os
import deepspeed
import torch
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

print(local_rank)
print(world_size)
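
# Note: `pipeline` below is assumed to be a transformers text-generation
# pipeline that has already been loaded with the gpt-j-6b checkpoint (not shown here).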

pipeline.model = deepspeed.init_inference(
    pipeline.model,
    mp_size=1,
    dtype=torch.float16,
    replace_method='auto',
)

will only return

0
1
[2021-08-28 15:05:58,839] [INFO] [logging.py:68:log_dist] [Rank -1] DeepSpeed info: version=0.4.4, git-hash=unknown, git-branch=unknown

DeepSpeed already supports the smaller GPT-Neo variants, so adding gpt-j-6b would make sense.

Additional context
If there is anything I can do (e.g., create a PR) with some guidance, I'd be happy to work on the issue and contribute as well.

@oborchers oborchers added the enhancement New feature or request label Aug 28, 2021
@RezaYazdaniAminabadi
Contributor

Hi @oborchers

Thanks for your request.
I agree that we need to support this model in DeepSpeed inference. There are, however, a few things that need to be added on the API side in order to make this run:

  1. We need a new replacement policy for this model, as it uses different parameters at each transformer block compared to GPT-Neo. I have created a branch to start with this here.
  2. Another important point is that GPT-J has a different execution path than GPT-Neo, as it uses the hidden_states as the input to both the self-attention and the MLP (https://github.com/StellaAthena/transformers/blob/master/src/transformers/models/gptj/modeling_gptj.py#L286); see the sketch after this list. So the inference API needs to change to reflect this, both for the runtime and for parallelism.
  3. Finally, there is a new parameter, rotary_dim, which is used for applying the position embedding to the query and key (https://github.com/StellaAthena/transformers/blob/master/src/transformers/models/gptj/modeling_gptj.py#L199). We need to think about how to support it in DeepSpeed.
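
For reference, here is a minimal, hypothetical sketch of the block layout behind points 2 and 3 (stand-in modules, not the Hugging Face implementation; the rotary embedding itself is omitted): GPT-J feeds one layer-norm output into both self-attention and the MLP and adds both results to the residual, whereas GPT-Neo runs attention and MLP sequentially with a second layer norm in between.

import torch.nn as nn

class GPTJBlockSketch(nn.Module):
    def __init__(self, hidden, heads):
        super().__init__()
        self.ln_1 = nn.LayerNorm(hidden)
        # stand-ins for the real GPT-J attention (which applies the rotary
        # embedding to the first rotary_dim channels of q/k) and MLP modules
        self.attn = nn.MultiheadAttention(hidden, heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(),
                                 nn.Linear(4 * hidden, hidden))

    def forward(self, hidden_states):
        ln_out = self.ln_1(hidden_states)                # single layer norm
        attn_out, _ = self.attn(ln_out, ln_out, ln_out)  # attention on ln_out
        mlp_out = self.mlp(ln_out)                       # MLP on the same ln_out
        return hidden_states + attn_out + mlp_out        # parallel residual add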

I think after resolving these issues, we can get this model running through DeepSpeed-Inference.
Thanks,
Reza

@oborchers
Author

Hi @RezaYazdaniAminabadi,

much appreciated, thanks for getting back to the request! 👍 In the meantime I already took some time to understand the model's behavior and the policy, and came up with something that runs. It doesn't produce anything useful (obviously), because I didn't consider 2 and 3, but it runs to some degree and can realize similar gains in inference speed compared to Neo on a V100.

Regarding 1: I came up with the following HFGPTJLayerPolicy based on the GPT2LayerPolicy and GPTNEOLayerPolicy:

class HFGPTJLayerPolicy(DSPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=True):
        super().__init__(inference, scale_attention=False)
        self.client_module = client_module
        try:
            import transformers

            HFGPTJLayerPolicy._orig_layer_class = (
                transformers.models.gptj.modeling_gptj.GPTJBlock
            )
        except (ImportError, AttributeError):
            HFGPTJLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return (
            self.client_module.attn.q_proj.weight.data.shape[1],
            self.client_module.attn.num_attention_heads,
        )

    def attention(self):
        qw = self.client_module.attn.q_proj.weight.data
        kw = self.client_module.attn.k_proj.weight.data
        vw = self.client_module.attn.v_proj.weight.data

        qkvw = torch.cat((qw, kw, vw), dim=0)

        return (
            self.linear_layer,
            qkvw,
            None,
            self.client_module.attn.out_proj.weight.data,
            None,
            self.scale_attention,
        )

    def mlp(self):
        return (
            self.linear_layer,
            self.client_module.mlp.fc_in.weight.data,
            self.client_module.mlp.fc_in.bias.data,
            self.client_module.mlp.fc_out.weight.data,
            self.client_module.mlp.fc_out.bias.data,
        )

    def layerNorm(self):
        return (
            None,
            None,
            self.client_module.ln_1.weight.data,
            self.client_module.ln_1.bias.data,
        )

This, however, has multiple caveats:

  1. The out_proj of GPTJAttention comes without a bias in attention(self). Therefore, replace_transformer_layer cannot work as-is, because dense_b is not expected to be None. Perhaps this can be circumvented by just initializing a zero vector of equal size within replace_transformer_layer if dense_b is None (sketched after this list)? That's how I solved it for now. I went through the CUDA code and the bias only seems to be used in additions, but maybe I am missing something.
  2. Because there is no second layer norm, in contrast to GPT-Neo, layerNorm(self) also returns two Nones to replace_transformer_layer, so same as above. Solving this might be a bit more tricky, because those tensors (attn_nw and attn_nb) are part of the function signature of the CUDA ds_mlp_gemm, namely gamma and beta here. I don't understand the C++ CUDA code well enough to judge whether the prior trick would actually work (probably not).
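
A minimal sketch of the zero-bias workaround from point 1 (a hypothetical helper, not the actual replace_transformer_layer code):

import torch

def dense_bias_or_zeros(dense_w, dense_b):
    # GPT-J's attention out_proj has no bias, so substitute a zero vector of
    # matching size; the kernel's bias addition then becomes a no-op.
    if dense_b is None:
        dense_b = torch.zeros(dense_w.shape[0],
                              dtype=dense_w.dtype,
                              device=dense_w.device)
    return dense_b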

Regarding 2: Yes. This is also reflected in the missing second layer norm, so just altering the config with something like second_layer_norm=True doesn't seem to be enough.

Regarding 3: Not even thought about that!

Thanks for the analysis and the support of the request 👍

All the best and a nice evening,
Oliver

@zgerrard

Any updates on this issue? @oborchers @RezaYazdaniAminabadi

@yovizzle

I'm also very interested in this one.

@RezaYazdaniAminabadi
Contributor

Hi @yovizzle @zgerrard and @oborchers

Thanks for your interest, and sorry for the delay in getting back to this thread.
I am working on adding support for this model. I will let you know soon.
Thanks,
Reza

@RezaYazdaniAminabadi
Contributor

Can you please try this PR to see if it works for this model?
Thanks,
Reza

@oborchers
Author

@RezaYazdaniAminabadi: Thank you for working on the issue! Much appreciated 👍

Without:
2.53 s ± 44.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
With:
2.48 s ± 25.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

(Single GPU online inference)

I'm assuming, based on the PR description, that this is mostly targeted at multi-GPU inference due to the tensor slicing, right?

@RezaYazdaniAminabadi
Contributor

Hi @oborchers
Yes, it is mostly for multi-GPU, but it is nice to see there is around a 2% improvement on a single GPU :) The inference kernels will be added in the next phase.
Thanks,
Reza

@joehoover

Thanks for this, @RezaYazdaniAminabadi!

Do you have an ETA on the inference kernels for GPT-J? Even a very rough ETA would be helpful.

@RezaYazdaniAminabadi
Contributor

Hi @joehoover,

I am going to be more focused on this through next week. I would say it is ready by early December.
Thanks,
Reza

@oborchers
Author

@RezaYazdaniAminabadi excellent! Thank you for that time estimate and the work on it 👍🏻

@dunalduck0

dunalduck0 commented Dec 10, 2021

@RezaYazdaniAminabadi I tried this PR but I got a strange outcome. It worked with 1 or 2 GPUs, but crashed with 3 GPUs.

Here is my code. The input file has 10 prompts.

import os
import torch
import deepspeed
import transformers
import pandas as pd
import contextlib
import time
from deepspeed import module_inject
from transformers import pipeline
from transformers.models.gptj.modeling_gptj import GPTJBlock
from transformers import AutoTokenizer, AutoModelForCausalLM

@contextlib.contextmanager
def timer(desc="job"):
    start_time = time.time()
    print(f'{desc} starts: {time.asctime(time.localtime(start_time))}')
    yield
    end_time = time.time()
    print(f'{desc} ends: {time.asctime(time.localtime(end_time))}')
    print(f'{desc} used {end_time - start_time} seconds')   

def initPipeline(model_path_or_name, local_rank, world_size,):
    generator = pipeline(
        'text-generation', 
        model=model_path_or_name, 
        device=local_rank)

    generator.model = deepspeed.init_inference(
        generator.model,
        mp_size=world_size,
        dtype=torch.float,
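        # injection_policy maps the GPT-J block class to the linear layers
        # (attention output projection and MLP fc_out) whose partial outputs
        # are all-reduced across GPUs when the weights are partitioned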
        injection_policy={GPTJBlock: ('attn.out_proj','mlp.fc_out')},
        replace_with_kernel_inject=False)

    generator.model.cuda().to(f'cuda:{local_rank}')

    return generator

df = pd.read_json('test-input.txt', encoding='utf-8', lines=True)
input_prompts = df['prompt'].tolist()

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

with timer(f"process {local_rank} init"):
    generator = initPipeline("EleutherAI/gpt-j-6B", local_rank, world_size)

with timer(f"process {local_rank} generate"):
    generator(input_prompts, max_new_tokens = 50)

When running with --include localhost:0 or --include localhost:0,1, it worked fine. It crashed with 3 GPUs.

deepspeed --include localhost:0,1,2 ds_pipeline.py

[2021-12-09 23:19:15,962] [WARNING] [runner.py:132:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2021-12-09 23:19:21,261] [INFO] [runner.py:398:main] cmd = /home/meiyang/bin/miniconda3/envs/gptj/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMSwgMl19 --master_addr=127.0.0.1 --master_port=29500 ds_pipeline.py
[2021-12-09 23:19:22,250] [INFO] [launch.py:80:main] WORLD INFO DICT: {'localhost': [0, 1, 2]}
[2021-12-09 23:19:22,251] [INFO] [launch.py:86:main] nnodes=1, num_local_procs=3, node_rank=0
[2021-12-09 23:19:22,251] [INFO] [launch.py:99:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0, 1, 2]})
[2021-12-09 23:19:22,251] [INFO] [launch.py:100:main] dist_world_size=3
[2021-12-09 23:19:22,251] [INFO] [launch.py:102:main] Setting CUDA_VISIBLE_DEVICES=0,1,2
process 2 init starts: Thu Dec  9 23:19:35 2021
process 1 init starts: Thu Dec  9 23:19:35 2021
process 0 init starts: Thu Dec  9 23:19:35 2021
[2021-12-09 23:20:42,094] [INFO] [logging.py:69:log_dist] [Rank -1] DeepSpeed info: version=0.5.8, git-hash=unknown, git-branch=unknown
[2021-12-09 23:20:42,095] [INFO] [engine.py:127:_init_quantization_setting] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
[2021-12-09 23:20:42,097] [INFO] [distributed.py:46:init_distributed] Initializing torch distributed with backend: nccl
[2021-12-09 23:20:42,102] [INFO] [logging.py:69:log_dist] [Rank -1] DeepSpeed info: version=0.5.8, git-hash=unknown, git-branch=unknown
[2021-12-09 23:20:42,103] [INFO] [engine.py:127:_init_quantization_setting] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
[2021-12-09 23:20:42,105] [INFO] [distributed.py:46:init_distributed] Initializing torch distributed with backend: nccl
[2021-12-09 23:20:42,142] [INFO] [logging.py:69:log_dist] [Rank -1] DeepSpeed info: version=0.5.8, git-hash=unknown, git-branch=unknown
[2021-12-09 23:20:42,142] [INFO] [engine.py:127:_init_quantization_setting] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
[2021-12-09 23:20:42,144] [INFO] [distributed.py:46:init_distributed] Initializing torch distributed with backend: nccl
[2021-12-09 23:20:43,141] [INFO] [engine.py:91:__init__] Place model to device: 0
[2021-12-09 23:20:43,144] [INFO] [engine.py:91:__init__] Place model to device: 1
process 0 init ends: Thu Dec  9 23:20:43 2021
process 0 init used 67.92620539665222 seconds
process 0 generate starts: Thu Dec  9 23:20:43 2021
[2021-12-09 23:20:43,145] [INFO] [engine.py:91:__init__] Place model to device: 2
process 1 init ends: Thu Dec  9 23:20:43 2021
process 1 init used 67.93008637428284 seconds
process 1 generate starts: Thu Dec  9 23:20:43 2021
process 2 init ends: Thu Dec  9 23:20:43 2021
process 2 init used 67.93062329292297 seconds
process 2 generate starts: Thu Dec  9 23:20:43 2021
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Traceback (most recent call last):
  File "ds_pipeline.py", line 49, in <module>
    generator(input_prompts, max_new_tokens = 50)
  File "/home/meiyang/src/transformers_fork/src/transformers/pipelines/text_generation.py", line 171, in __call__
    return super().__call__(text_inputs, **kwargs)
  File "/home/meiyang/src/transformers_fork/src/transformers/pipelines/base.py", line 1086, in __call__
    outputs = [output for output in final_iterator]
  File "/home/meiyang/src/transformers_fork/src/transformers/pipelines/base.py", line 1086, in <listcomp>
    outputs = [output for output in final_iterator]
  File "/home/meiyang/src/transformers_fork/src/transformers/pipelines/base.py", line 771, in __next__
    item = next(self.iterator)
  File "/home/meiyang/src/transformers_fork/src/transformers/pipelines/base.py", line 772, in __next__
    processed = self.infer(item, **self.params)
  File "/home/meiyang/src/transformers_fork/src/transformers/pipelines/base.py", line 1036, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/meiyang/src/transformers_fork/src/transformers/pipelines/text_generation.py", line 206, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL
  File "/home/meiyang/bin/miniconda3/envs/gptj/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/meiyang/src/transformers_fork/src/transformers/generation_utils.py", line 1033, in generate
    return self.sample(
  File "/home/meiyang/src/transformers_fork/src/transformers/generation_utils.py", line 1547, in sample
    outputs = self(
  File "/home/meiyang/bin/miniconda3/envs/gptj/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/meiyang/bin/miniconda3/envs/gptj/lib/python3.8/site-packages/deepspeed/inference/engine.py", line 244, in forward
    outputs = self.model_orig_fwd(*inputs, **kwargs)
  File "/home/meiyang/src/transformers_fork/src/transformers/models/gptj/modeling_gptj.py", line 782, in forward
    transformer_outputs = self.transformer(
  File "/home/meiyang/bin/miniconda3/envs/gptj/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/meiyang/src/transformers_fork/src/transformers/models/gptj/modeling_gptj.py", line 636, in forward
    outputs = block(
  File "/home/meiyang/bin/miniconda3/envs/gptj/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/meiyang/src/transformers_fork/src/transformers/models/gptj/modeling_gptj.py", line 280, in forward
    attn_outputs = self.attn(
  File "/home/meiyang/bin/miniconda3/envs/gptj/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/meiyang/src/transformers_fork/src/transformers/models/gptj/modeling_gptj.py", line 185, in forward
    query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
  File "/home/meiyang/src/transformers_fork/src/transformers/models/gptj/modeling_gptj.py", line 110, in _split_heads
    tensor = tensor.view(*new_shape)
RuntimeError: shape '[1, 284, 5, 256]' is invalid for input of size 387660
(All three ranks raised the same RuntimeError.)
[2021-12-09 23:20:45,373] [INFO] [launch.py:131:sigkill_handler] Killing subprocess 1103263
[2021-12-09 23:20:45,373] [INFO] [launch.py:131:sigkill_handler] Killing subprocess 1103264
[2021-12-09 23:20:45,373] [INFO] [launch.py:131:sigkill_handler] Killing subprocess 1103265
[2021-12-09 23:20:45,373] [ERROR] [launch.py:137:sigkill_handler] ['/home/meiyang/bin/miniconda3/envs/gptj/bin/python', '-u', 'ds_pipeline.py', '--local_rank=2'] exits with return code = 1

@RezaYazdaniAminabadi
Contributor

Hi @dunalduck0

Thanks for trying this.
For model parallelism to work, the model dimensions should be divisible by the number of GPUs you are using. Since they cannot be divided evenly, you get this error at one of the reshaping steps in the transformer. However, I think this should raise a proper error so that the cause is clearer.

Best,
Reza

@dunalduck0

Thank you @RezaYazdaniAminabadi. For the text-generation task, input lengths normally vary. Does that mean we need to pad the input so that the dimensions are divisible? If so, how do I do that?

@RezaYazdaniAminabadi
Contributor

Sorry, I meant the model dimensions, such as the hidden size and the number of attention heads. This is due to partitioning the weights across GPUs. The input, however, will not be partitioned but broadcast to the GPUs.
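
For concreteness, a quick sanity check of this constraint against GPT-J-6B's published dimensions (hidden size 4096, 16 attention heads):

hidden_size, num_heads = 4096, 16  # GPT-J-6B

for mp_size in (1, 2, 3, 4):
    divisible = hidden_size % mp_size == 0 and num_heads % mp_size == 0
    print(mp_size, "ok" if divisible else "not divisible -> reshape error")

# mp_size=3 fails because neither 4096 nor 16 is divisible by 3, which matches
# the view() error in the traceback above.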

@dunalduck0

dunalduck0 commented Dec 12, 2021 via email

@RezaYazdaniAminabadi
Contributor

Yes, this is the case. Even for some models, like GPT-2, which has 25 heads, the lowest mp_size we can set above 1 is 5. So we have this constraint based on the model structure.

@dunalduck0

dunalduck0 commented Dec 14, 2021

It sounds like you want to split the model into equal shares across GPUs. But why can't you split it into different sizes? Say the first two GPUs get 9 heads each and the third gets 8, for a total of 25 heads?

Sorry, I actually don't know much about DL model structures. Apologies in advance if my suggestion is wildly stupid :P.

@Leezekun

Leezekun commented Dec 24, 2021

> Hi @joehoover,
>
> I am going to be more focused on this through next week. I would say it is ready by early December. Thanks, Reza!

Any update on the GPT-J-6B Inference kernel?
Thanks!

@Leezekun

> Hi @dunalduck0
>
> Thanks for trying this. For model parallelism to work, the model dimensions should be divisible by the number of GPUs you are using. Since they cannot be divided evenly, you get this error at one of the reshaping steps in the transformer. However, I think this should raise a proper error so that the cause is clearer.
>
> Best, Reza

Hi, I also found that when using DeepSpeed to speed up GPT-J inference, only 1 or 2 GPUs work, but it crashes with 3 GPUs.

Besides, I found that using 2 GPUs is not faster than using 1, except that the large model can now fit across the 2 GPUs.
Is this expected? What should I do if I want to not only parallelize the model and solve the OOM problem, but also speed up GPT-J inference?

Thanks!

@joehoover

Hi @Leezekun,

A large decrease in latency is not an expected result of mere model parallelism. That's what the kernels are for.

Also, MP for GPT-J won't work with 3 devices. The degree of MP is constrained by the dimensions of the model.

@Leezekun

Hi @joehoover,

Thanks for the clarification. Do you know any updates on the GPT-J-6B Inference kernel?

@RezaYazdaniAminabadi
Contributor

Hi @Leezekun

It is true that parallelism alone may not improve performance, and I agree with @joehoover that kernels are required to get higher performance.
I was planning to release the kernel support in December, but I fell behind my schedule! I will follow up on this thread more next week.

Thanks,
Reza

@Leezekun

Leezekun commented Jan 2, 2022

@RezaYazdaniAminabadi

Thanks a lot! Look forward to the update.

@RezaYazdaniAminabadi
Contributor

Hi everyone,

I have added this PR to run the GPT-J model through DeepSpeed. Can you please try it and see if it works on your side?
I use this script:

import os
import torch
import deepspeed
import transformers

from deepspeed import module_inject
from transformers import pipeline
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoBlock as gpt2_transformer

# Get local gpu rank from torch.distributed/deepspeed launcher
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

print(
    "***************** Creating model in RANK ({0}) with WORLD_SIZE = {1} *****************"
    .format(local_rank,
            world_size))
generator = pipeline('text-generation',
                     model='EleutherAI/gpt-j-6B',
                     #model='EleutherAI/gpt-neo-2.7B',
                     device=local_rank)
generator.model = deepspeed.init_inference(generator.model,
                                           mp_size=world_size,
                                           dtype=torch.float,
                                           replace_method='auto',
                                           replace_with_kernel_inject=True)
string = generator("DeepSpeed is", do_sample=True, min_length=50)
print(string)

Best,
Reza

@Leezekun

Leezekun commented Jan 8, 2022

@RezaYazdaniAminabadi

Thanks for the PR. But when I tried it using the script you provided, I got the following errors:

[2022-01-07 23:32:02,293] [WARNING] [runner.py:132:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2022-01-07 23:32:02,924] [INFO] [runner.py:398:main] cmd = /home/zekun/miniconda3/envs/gptj-torch1.9/bin/python3.6 -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMV19 --master_addr=127.0.0.1 --master_port=29500 interact_gpt_deepspeed.py
[2022-01-07 23:32:03,726] [INFO] [launch.py:80:main] WORLD INFO DICT: {'localhost': [0, 1]}
[2022-01-07 23:32:03,726] [INFO] [launch.py:87:main] nnodes=1, num_local_procs=2, node_rank=0
[2022-01-07 23:32:03,726] [INFO] [launch.py:99:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0, 1]})
[2022-01-07 23:32:03,726] [INFO] [launch.py:100:main] dist_world_size=2
[2022-01-07 23:32:03,726] [INFO] [launch.py:102:main] Setting CUDA_VISIBLE_DEVICES=0,1
***************** Creating model in RANK (0) with WORLD_SIZE = 2 *****************
***************** Creating model in RANK (1) with WORLD_SIZE = 2 *****************
[2022-01-07 23:33:07,668] [INFO] [logging.py:69:log_dist] [Rank -1] DeepSpeed info: version=0.5.9+8a99292, git-hash=8a99292, git-branch=gptj-inference-support
[2022-01-07 23:33:07,669] [INFO] [engine.py:127:_init_quantization_setting] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
[2022-01-07 23:33:07,672] [INFO] [distributed.py:47:init_distributed] Initializing torch distributed with backend: nccl
[2022-01-07 23:33:19,822] [INFO] [logging.py:69:log_dist] [Rank -1] DeepSpeed info: version=0.5.9+8a99292, git-hash=8a99292, git-branch=gptj-inference-support
[2022-01-07 23:33:19,822] [INFO] [engine.py:127:_init_quantization_setting] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
[2022-01-07 23:33:19,825] [INFO] [distributed.py:47:init_distributed] Initializing torch distributed with backend: nccl
Installed CUDA version 11.2 does not match the version torch was compiled with 11.1 but since the APIs are compatible, accepting this combination
Using /home/zekun/.cache/torch_extensions as PyTorch extensions root...
Installed CUDA version 11.2 does not match the version torch was compiled with 11.1 but since the APIs are compatible, accepting this combination
Using /home/zekun/.cache/torch_extensions as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/zekun/.cache/torch_extensions/transformer_inference/build.ninja...
Building extension module transformer_inference...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/7] /usr/local/cuda/bin/nvcc  -DTORCH_EXTENSION_NAME=transformer_inference -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/includes -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/TH -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/include/python3.6m -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86 -c /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu -o apply_rotary_pos_emb.cuda.o 
FAILED: apply_rotary_pos_emb.cuda.o 
/usr/local/cuda/bin/nvcc  -DTORCH_EXTENSION_NAME=transformer_inference -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/includes -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/TH -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/include/python3.6m -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86 -c /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu -o apply_rotary_pos_emb.cuda.o 
/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu(71): error: identifier "lane" is undefined

1 error detected in the compilation of "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu".
[2/7] /usr/local/cuda/bin/nvcc  -DTORCH_EXTENSION_NAME=transformer_inference -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/includes -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/TH -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/include/python3.6m -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86 -c /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/csrc/dequantize.cu -o dequantize.cuda.o 
[3/7] /usr/local/cuda/bin/nvcc  -DTORCH_EXTENSION_NAME=transformer_inference -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/includes -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/TH -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/include/python3.6m -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86 -c /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/csrc/gelu.cu -o gelu.cuda.o 
[4/7] /usr/local/cuda/bin/nvcc  -DTORCH_EXTENSION_NAME=transformer_inference -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/includes -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/TH -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/include/python3.6m -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86 -c /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/csrc/softmax.cu -o softmax.cuda.o 
[5/7] /usr/local/cuda/bin/nvcc  -DTORCH_EXTENSION_NAME=transformer_inference -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/includes -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/TH -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/include/python3.6m -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_86,code=compute_86 -c /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/csrc/normalize.cu -o normalize.cuda.o 
/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/csrc/normalize.cu(21): warning: variable "iterations" was declared but never referenced

/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/csrc/normalize.cu(90): warning: variable "iterations" was declared but never referenced

[6/7] c++ -MMD -MF pt_binding.o.d -DTORCH_EXTENSION_NAME=transformer_inference -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/includes -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/TH -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/zekun/miniconda3/envs/gptj-torch1.9/include/python3.6m -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++14 -O3 -std=c++14 -g -Wno-reorder -c /home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/csrc/transformer/inference/csrc/pt_binding.cpp -o pt_binding.o 
ninja: build stopped: subcommand failed.
Traceback (most recent call last):
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 1672, in _run_ninja_build
    env=env)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/subprocess.py", line 438, in run
    output=stdout, stderr=stderr)
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "interact_gpt_deepspeed.py", line 209, in <module>
    replace_with_kernel_inject=True)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/__init__.py", line 283, in init_inference
    replace_with_kernel_inject)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/inference/engine.py", line 88, in __init__
    replace_with_kernel_inject=replace_with_kernel_inject)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/inference/engine.py", line 175, in _apply_injection_policy
    replace_with_kernel_inject=replace_with_kernel_inject)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/module_inject/replace_module.py", line 467, in replace_transformer_layer
    _replace_policy=policy)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/module_inject/replace_module.py", line 561, in replace_module
    replaced_module, _ = _replace_module(model, policy)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/module_inject/replace_module.py", line 583, in _replace_module
    _, layer_id = _replace_module(child, policies, layer_id=layer_id)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/module_inject/replace_module.py", line 583, in _replace_module
    _, layer_id = _replace_module(child, policies, layer_id=layer_id)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/module_inject/replace_module.py", line 580, in _replace_module
    layer_id))
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/module_inject/replace_module.py", line 458, in replace_fn
    layer_id=layer_id)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/module_inject/replace_module.py", line 264, in replace_with_policy
    mp_group=mp_group,
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/transformer/inference/transformer_inference.py", line 571, in __init__
    inference_cuda_module = op_builder.InferenceBuilder().load()
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/op_builder/builder.py", line 403, in load
    return self.jit_load(verbose)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/op_builder/builder.py", line 442, in jit_load
    verbose=verbose)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 1092, in load
    keep_intermediates=keep_intermediates)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 1303, in _jit_compile
    is_standalone=is_standalone)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 1408, in _write_ninja_file_and_build_library
    error_prefix=f"Error building extension '{name}'")
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 1682, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error building extension 'transformer_inference'
Loading extension module transformer_inference...
Traceback (most recent call last):
  File "interact_gpt_deepspeed.py", line 209, in <module>
    replace_with_kernel_inject=True)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/__init__.py", line 283, in init_inference
    replace_with_kernel_inject)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/inference/engine.py", line 88, in __init__
    replace_with_kernel_inject=replace_with_kernel_inject)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/inference/engine.py", line 175, in _apply_injection_policy
    replace_with_kernel_inject=replace_with_kernel_inject)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/module_inject/replace_module.py", line 467, in replace_transformer_layer
    _replace_policy=policy)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/module_inject/replace_module.py", line 561, in replace_module
    replaced_module, _ = _replace_module(model, policy)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/module_inject/replace_module.py", line 583, in _replace_module
    _, layer_id = _replace_module(child, policies, layer_id=layer_id)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/module_inject/replace_module.py", line 583, in _replace_module
    _, layer_id = _replace_module(child, policies, layer_id=layer_id)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/module_inject/replace_module.py", line 580, in _replace_module
    layer_id))
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/module_inject/replace_module.py", line 458, in replace_fn
    layer_id=layer_id)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/module_inject/replace_module.py", line 264, in replace_with_policy
    mp_group=mp_group,
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/transformer/inference/transformer_inference.py", line 571, in __init__
    inference_cuda_module = op_builder.InferenceBuilder().load()
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/op_builder/builder.py", line 403, in load
    return self.jit_load(verbose)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/deepspeed/ops/op_builder/builder.py", line 442, in jit_load
    verbose=verbose)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 1092, in load
    keep_intermediates=keep_intermediates)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 1318, in _jit_compile
    return _import_module_from_library(name, build_directory, is_python_module)
  File "/home/zekun/miniconda3/envs/gptj-torch1.9/lib/python3.6/site-packages/torch/utils/cpp_extension.py", line 1701, in _import_module_from_library
    module = importlib.util.module_from_spec(spec)
  File "<frozen importlib._bootstrap>", line 571, in module_from_spec
  File "<frozen importlib._bootstrap_external>", line 922, in create_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
ImportError: /home/zekun/.cache/torch_extensions/transformer_inference/transformer_inference.so: cannot open shared object file: No such file or directory
[2022-01-07 23:33:39,856] [INFO] [launch.py:131:sigkill_handler] Killing subprocess 2311
[2022-01-07 23:33:39,856] [INFO] [launch.py:131:sigkill_handler] Killing subprocess 2312
[2022-01-07 23:33:39,856] [ERROR] [launch.py:137:sigkill_handler] ['/home/zekun/miniconda3/envs/gptj-torch1.9/bin/python3.6', '-u', 'interact_gpt_deepspeed.py', '--local_rank=1'] exits with return code = 1

@oborchers
Author

Awesome @RezaYazdaniAminabadi! Highly appreciated and thanks for tackling the issue 💯
For me, I am running into the following problem:

CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:
....
RuntimeError: Error building extension 'transformer_inference'

Steps to replicate:

  1. git clone repository
  2. python setup.py bdist_wheel
  3. ds_report
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
sparse_attn ............ [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
async_io ............... [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch']
torch version .................... 1.9.0+cu111
torch cuda version ............... 11.1
nvcc version ..................... 11.1
deepspeed install path ........... ['/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.5.10+289c3f9, 289c3f9, master
deepspeed wheel compiled w. ...... torch 1.9, cuda 11.1
  4. nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 455.32.00    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+

The same happens when I do:

  • pip install .[dev,1bit,autotuning] from the CI
  • re-run the script with torch==1.8.1
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
sparse_attn ............ [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
async_io ............... [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch']
torch version .................... 1.8.1+cu111
torch cuda version ............... 11.1
nvcc version ..................... 11.1
deepspeed install path ........... ['/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.5.10+289c3f9, 289c3f9, master
deepspeed wheel compiled w. ...... torch 1.8, cuda 11.1

Upgrading ninja also did not work: Successfully installed ninja-1.10.2.3

and when running: DS_BUILD_OPS=1 pip install .[dev,1bit,autotuning] --global-option="build_ext" --global-option="-j8":

creating build/lib.linux-x86_64-3.8/deepspeed/ops/quantizer
    g++ -pthread -shared -B /home/oborchers/anaconda3/envs/dev/compiler_compat -L/home/oborchers/anaconda3/envs/dev/lib -Wl,-rpath=/home/oborchers/anaconda3/envs/dev/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-3.8/csrc/quantization/pt_binding.o build/temp.linux-x86_64-3.8/csrc/quantization/quantizer.o -L/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/lib -L/usr/local/cuda-11.1/lib64 -L/usr/local/cuda-11.1/lib64 -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda_cu -ltorch_cuda_cpp -o build/lib.linux-x86_64-3.8/deepspeed/ops/quantizer/quantizer_op.cpython-38-x86_64-linux-gnu.so
    /usr/local/cuda-11.1/bin/nvcc -Icsrc/includes -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/TH -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.1/include -I/usr/local/cuda-11.1/include -I/home/oborchers/anaconda3/envs/dev/include/python3.8 -c csrc/transformer/normalize_kernels.cu -o build/temp.linux-x86_64-3.8/csrc/transformer/normalize_kernels.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -D__STOCHASTIC_MODE__ -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="_gcc" -DPYBIND11_STDLIB="_libstdcpp" -DPYBIND11_BUILD_ABI="_cxxabi1011" -DTORCH_EXTENSION_NAME=stochastic_transformer_op -D_GLIBCXX_USE_CXX11_ABI=0
    /usr/local/cuda-11.1/bin/nvcc -Icsrc/includes -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/TH -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.1/include -I/usr/local/cuda-11.1/include -I/home/oborchers/anaconda3/envs/dev/include/python3.8 -c csrc/transformer/normalize_kernels.cu -o build/temp.linux-x86_64-3.8/csrc/transformer/normalize_kernels.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="_gcc" -DPYBIND11_STDLIB="_libstdcpp" -DPYBIND11_BUILD_ABI="_cxxabi1011" -DTORCH_EXTENSION_NAME=transformer_op -D_GLIBCXX_USE_CXX11_ABI=0
    csrc/transformer/normalize_kernels.cu(1044): warning: variable "block_dim" was declared but never referenced
    (the same warning is emitted several more times)
    /usr/local/cuda-11.1/bin/nvcc -Icsrc/includes -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/TH -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.1/include -I/usr/local/cuda-11.1/include -I/home/oborchers/anaconda3/envs/dev/include/python3.8 -c csrc/transformer/softmax_kernels.cu -o build/temp.linux-x86_64-3.8/csrc/transformer/softmax_kernels.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -D__STOCHASTIC_MODE__ -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="_gcc" -DPYBIND11_STDLIB="_libstdcpp" -DPYBIND11_BUILD_ABI="_cxxabi1011" -DTORCH_EXTENSION_NAME=stochastic_transformer_op -D_GLIBCXX_USE_CXX11_ABI=0
    /usr/local/cuda-11.1/bin/nvcc -Icsrc/includes -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/TH -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.1/include -I/usr/local/cuda-11.1/include -I/home/oborchers/anaconda3/envs/dev/include/python3.8 -c csrc/transformer/softmax_kernels.cu -o build/temp.linux-x86_64-3.8/csrc/transformer/softmax_kernels.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="_gcc" -DPYBIND11_STDLIB="_libstdcpp" -DPYBIND11_BUILD_ABI="_cxxabi1011" -DTORCH_EXTENSION_NAME=transformer_op -D_GLIBCXX_USE_CXX11_ABI=0
    /usr/local/cuda-11.1/bin/nvcc -Icsrc/includes -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/TH -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.1/include -I/usr/local/cuda-11.1/include -I/home/oborchers/anaconda3/envs/dev/include/python3.8 -c csrc/transformer/general_kernels.cu -o build/temp.linux-x86_64-3.8/csrc/transformer/general_kernels.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -D__STOCHASTIC_MODE__ -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="_gcc" -DPYBIND11_STDLIB="_libstdcpp" -DPYBIND11_BUILD_ABI="_cxxabi1011" -DTORCH_EXTENSION_NAME=stochastic_transformer_op -D_GLIBCXX_USE_CXX11_ABI=0
    /usr/local/cuda-11.1/bin/nvcc -Icsrc/includes -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/TH -I/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.1/include -I/usr/local/cuda-11.1/include -I/home/oborchers/anaconda3/envs/dev/include/python3.8 -c csrc/transformer/general_kernels.cu -o build/temp.linux-x86_64-3.8/csrc/transformer/general_kernels.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="_gcc" -DPYBIND11_STDLIB="_libstdcpp" -DPYBIND11_BUILD_ABI="_cxxabi1011" -DTORCH_EXTENSION_NAME=transformer_op -D_GLIBCXX_USE_CXX11_ABI=0
creating build/lib.linux-x86_64-3.8/deepspeed/ops/transformer
    g++ -pthread -shared -B /home/oborchers/anaconda3/envs/dev/compiler_compat -L/home/oborchers/anaconda3/envs/dev/lib -Wl,-rpath=/home/oborchers/anaconda3/envs/dev/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-3.8/csrc/transformer/ds_transformer_cuda.o build/temp.linux-x86_64-3.8/csrc/transformer/cublas_wrappers.o build/temp.linux-x86_64-3.8/csrc/transformer/transform_kernels.o build/temp.linux-x86_64-3.8/csrc/transformer/gelu_kernels.o build/temp.linux-x86_64-3.8/csrc/transformer/dropout_kernels.o build/temp.linux-x86_64-3.8/csrc/transformer/normalize_kernels.o build/temp.linux-x86_64-3.8/csrc/transformer/softmax_kernels.o build/temp.linux-x86_64-3.8/csrc/transformer/general_kernels.o -L/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/lib -L/usr/local/cuda-11.1/lib64 -L/usr/local/cuda-11.1/lib64 -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda_cu -ltorch_cuda_cpp -o build/lib.linux-x86_64-3.8/deepspeed/ops/transformer/stochastic_transformer_op.cpython-38-x86_64-linux-gnu.so
    g++ -pthread -shared -B /home/oborchers/anaconda3/envs/dev/compiler_compat -L/home/oborchers/anaconda3/envs/dev/lib -Wl,-rpath=/home/oborchers/anaconda3/envs/dev/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-3.8/csrc/transformer/ds_transformer_cuda.o build/temp.linux-x86_64-3.8/csrc/transformer/cublas_wrappers.o build/temp.linux-x86_64-3.8/csrc/transformer/transform_kernels.o build/temp.linux-x86_64-3.8/csrc/transformer/gelu_kernels.o build/temp.linux-x86_64-3.8/csrc/transformer/dropout_kernels.o build/temp.linux-x86_64-3.8/csrc/transformer/normalize_kernels.o build/temp.linux-x86_64-3.8/csrc/transformer/softmax_kernels.o build/temp.linux-x86_64-3.8/csrc/transformer/general_kernels.o -L/home/oborchers/anaconda3/envs/dev/lib/python3.8/site-packages/torch/lib -L/usr/local/cuda-11.1/lib64 -L/usr/local/cuda-11.1/lib64 -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda_cu -ltorch_cuda_cpp -o build/lib.linux-x86_64-3.8/deepspeed/ops/transformer/transformer_op.cpython-38-x86_64-linux-gnu.so
    error: command '/usr/local/cuda-11.1/bin/nvcc' failed with exit status 1
    ----------------------------------------
ERROR: Command errored out with exit status 1: /home/oborchers/anaconda3/envs/dev/bin/python -u -c 'import io, os, sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-req-build-0t_ejf58/setup.py'"'"'; __file__='"'"'/tmp/pip-req-build-0t_ejf58/setup.py'"'"';f = getattr(tokenize, '"'"'open'"'"', open)(__file__) if os.path.exists(__file__) else io.StringIO('"'"'from setuptools import setup; setup()'"'"');code = f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' build_ext -j8 install --record /tmp/pip-record-mzks2_nm/install-record.txt --single-version-externally-managed --compile --install-headers /home/oborchers/anaconda3/envs/dev/include/python3.8/deepspeed Check the logs for full command output.

Did I miss something?

@oborchers
Author

oborchers commented Jan 8, 2022

Does also not work in a clean environment:

[1/7] /usr/local/cuda/bin/nvcc -DTORCH_EXTENSION_NAME=transformer_inference -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -I/workspace/DeepSpeed/deepspeed/ops/csrc/transformer/inference/includes -isystem /opt/conda/lib/python3.8/site-packages/torch/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_52,code=sm_52 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_60,code=sm_60 --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_70,code=compute_70 -c /workspace/DeepSpeed/deepspeed/ops/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu -o apply_rotary_pos_emb.cuda.o
FAILED: apply_rotary_pos_emb.cuda.o
RuntimeError: Error building extension 'transformer_inference'

Tested on docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:20.12-py3

@RezaYazdaniAminabadi
Contributor

@oborchers @Leezekun,

Thanks for trying this. Yes, there is an issue with the half-precision kernels; I am creating another PR to fix it.

@oborchers
Author

@RezaYazdaniAminabadi: This is working! Great job! 💯 In terms of performance:

PyTorch (1.9 + cu111) + transformers 4.15 + 1x V100 and max_length = 64 tokens:

generator("DeepSpeed is", do_sample=True, max_length=max_length, pad_token_id=50256, eos_token_id=50256)

Without kernels:

3.44 s ± 11.6 ms per loop (mean ± std. dev. of 25 runs, 1 loop each)

With kernels:

1.97 s ± 788 µs per loop (mean ± std. dev. of 25 runs, 1 loop each)

-> Decrease: ~43%
-> Generated Result:

[{'generated_text': 'DeepSpeed is not a well-known brand. In fact, the main brand that you would find at the top of a Google search is DeepSpeed Technology. So what’s the difference? I was lucky enough to get a chance to talk with DeepSpeed CEO and Co-Founder, David Schubert'}]

Without kernels (fp16):

2.99 s ± 45.3 ms per loop (mean ± std. dev. of 25 runs, 1 loop each)

With kernels (fp16):

1.1 s ± 1.4 ms per loop (mean ± std. dev. of 25 runs, 1 loop each)

-> Decrease: ~63% 🚀
-> Generated Result:

[{'generated_text': 'DeepSpeed is a fast growing startup that helps you to find the right doctors and medical clinics. It connects patients to the doctors who are closest to them, based on their location and their personal doctor recommendations. DeepSpeed was previously known as DocSense and was launched in 2007. DeepSpeed is backed by Accel Partners,'}]
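For reference, these figures look like IPython %timeit output (mean ± std. dev. of 25 runs, 1 loop each); a plain-Python way to reproduce that kind of measurement, assuming a generator pipeline like the one sketched above:

import statistics
import timeit

# 25 runs of a single generation each, mirroring "%timeit -r 25 -n 1 generator(...)".
runs = timeit.repeat(
    lambda: generator("DeepSpeed is", do_sample=True, max_length=64,
                      pad_token_id=50256, eos_token_id=50256),
    repeat=25, number=1)
print(f"{statistics.mean(runs):.2f} s ± {statistics.stdev(runs) * 1000:.1f} ms per loop (25 runs)")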

Still, some issues remain, though they may not be directly related to this one. Shall I open a new issue for them?

  1. Install only partially works:
  • python setup.py bdist_wheel + pip install dist/*.whl works
  • DS_BUILD_OPS=1 pip install . --global-option="build_ext" --global-option="-j8" works only in the container: nvcr.io/nvidia/pytorch:20.12-py3 (didn't test any others).

Results in:

csrc/transformer/normalize_kernels.cu(1044): warning: variable "block_dim" was declared but never referenced

csrc/transformer/normalize_kernels.cu(896): error: no operator "*=" matches these operands
            operand types are: __half2 *= __half2

csrc/transformer/normalize_kernels.cu(899): error: no operator "-" matches these operands
            operand types are: const __half2 - const __half2

csrc/transformer/normalize_kernels.cu(901): error: ambiguous "?" operation: second operand of type "<error-type>" can be converted to third operand type "const __half2", and vice versa

csrc/transformer/normalize_kernels.cu(906): error: no operator "*=" matches these operands
            operand types are: __half2 *= __half2

csrc/transformer/normalize_kernels.cu(908): error: no operator "-" matches these operands
            operand types are: const __half2 - const __half2

csrc/transformer/normalize_kernels.cu(909): error: ambiguous "?" operation: second operand of type "<error-type>" can be converted to third operand type "const __half2", and vice versa
....
csrc/transformer/normalize_kernels.cu(2045): error: no operator "+" matches these operands
            operand types are: __half2 + const __half2

39 errors detected in the compilation of "csrc/transformer/normalize_kernels.cu".
error: command '/usr/local/cuda/bin/nvcc' failed with exit status 1
  2. Running this version for gpt-neo-1.3b results in the following error:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-1-112548c80691> in <module>
     21                      device=local_rank)
     22 
---> 23 generator.model = deepspeed.init_inference(
     24     generator.model,
     25     mp_size=world_size,

~/anaconda3/envs/dev/lib/python3.8/site-packages/deepspeed/__init__.py in init_inference(model, mp_size, mpu, checkpoint, module_key, dtype, injection_policy, replace_method, quantization_setting, replace_with_kernel_inject, return_tuple)
    272         raise NotImplementedError("pipeline module support is not implemented yet")
    273     else:
--> 274         engine = InferenceEngine(model,
    275                                  mp_size,
    276                                  mpu,

~/anaconda3/envs/dev/lib/python3.8/site-packages/deepspeed/inference/engine.py in __init__(self, model, mp_size, mpu, checkpoint, dtype, injection_dict, return_tuple, replace_method, quantization_setting, replace_with_kernel_inject)
     84                                              replace_with_kernel_inject)
     85         elif replace_method == 'auto':
---> 86             self._apply_injection_policy(
     87                 return_tuple=return_tuple,
     88                 replace_with_kernel_inject=replace_with_kernel_inject)

~/anaconda3/envs/dev/lib/python3.8/site-packages/deepspeed/inference/engine.py in _apply_injection_policy(self, client_module, injection_policy, return_tuple, replace_with_kernel_inject)
    159                                 replace_with_kernel_inject=False):
    160 
--> 161         replace_transformer_layer(client_module,
    162                                   self.module,
    163                                   policy=injection_policy,

~/anaconda3/envs/dev/lib/python3.8/site-packages/deepspeed/module_inject/replace_module.py in replace_transformer_layer(orig_layer_impl, model, policy, micro_batch_size, config, seed, hidden_size, num_attention_heads, mp_size, mp_group, preln, fp16, local_rank, stochastic_mode, training, quantize, quantize_settings, return_tuple, replace_with_kernel_inject, linear_layer_setting)
    462         return new_module
    463 
--> 464     return replace_module(model=model,
    465                           orig_class=orig_layer_impl,
    466                           replace_fn=replace_fn,

~/anaconda3/envs/dev/lib/python3.8/site-packages/deepspeed/module_inject/replace_module.py in replace_module(model, orig_class, replace_fn, _replace_policy)
    559         "You can find some samples here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py"
    560 
--> 561     replaced_module, _ = _replace_module(model, policy)
    562     return replaced_module
    563 

~/anaconda3/envs/dev/lib/python3.8/site-packages/deepspeed/module_inject/replace_module.py in _replace_module(model, policies, layer_id)
    581             layer_id += 1
    582         else:
--> 583             _, layer_id = _replace_module(child, policies, layer_id=layer_id)
    584 
    585     return model, layer_id

~/anaconda3/envs/dev/lib/python3.8/site-packages/deepspeed/module_inject/replace_module.py in _replace_module(model, policies, layer_id)
    581             layer_id += 1
    582         else:
--> 583             _, layer_id = _replace_module(child, policies, layer_id=layer_id)
    584 
    585     return model, layer_id

~/anaconda3/envs/dev/lib/python3.8/site-packages/deepspeed/module_inject/replace_module.py in _replace_module(model, policies, layer_id)
    576                 model,
    577                 name,
--> 578                 policies[child.__class__][0](child,
    579                                              policies[child.__class__][-1],
    580                                              layer_id))

~/anaconda3/envs/dev/lib/python3.8/site-packages/deepspeed/module_inject/replace_module.py in replace_fn(child, _policy, layer_id)
    451             # copy relevant state from child -> new module
    452             if replace_with_kernel_inject:
--> 453                 new_module = replace_with_policy(child,
    454                                                  _policy,
    455                                                  inference=True,

~/anaconda3/envs/dev/lib/python3.8/site-packages/deepspeed/module_inject/replace_module.py in replace_with_policy(child, policy_cls, inference, preln, layer_id)
    288 
    289             attn_block.attn_ow.data = mp_replace.copy(attn_block.attn_ow.data, dense_w)
--> 290             attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob.data, dense_b)
    291 
    292             mpl_block = new_module.mlp

~/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/module.py in __setattr__(self, name, value)
   1149         elif params is not None and name in params:
   1150             if value is not None:
-> 1151                 raise TypeError("cannot assign '{}' as parameter '{}' "
   1152                                 "(torch.nn.Parameter or None expected)"
   1153                                 .format(torch.typename(value), name))

TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'attn_ob' (torch.nn.Parameter or None expected)

@RezaYazdaniAminabadi
Contributor

Hi @joehoover

Thanks for trying this out. Great performance results, I am happy to see such good improvement.
Yes, please open these issues in another thread, since they are not related to this one. I am gonna close this issue, then.
Best,
Reza

@joehoover

joehoover commented Jan 11, 2022

@RezaYazdaniAminabadi, thanks so much for putting this together!

Quick question: should I expect DeepSpeed inference to add memory overhead? I've been using a 16GB T4 for inference dev and I can fit the FP16 GPT-J weights on that device with room to spare. However, when I initialize DeepSpeed inference, I'm running out of VRAM.

Just want to make sure I'm not making a mistake somewhere.

@oborchers
Author

oborchers commented Jan 12, 2022

@joehoover yes. You need to load the model on CPU and then let DeepSpeed do the conversion, which moves it to the GPU.

@RezaYazdaniAminabadi do you think there's a way to limit memory consumption during injection?
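For reference, a minimal sketch of that load order (model id and generation call are assumptions, not an official DeepSpeed script): the fp16 weights stay in host RAM until init_inference injects the kernels and places them on the GPU.

import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/gpt-j-6B"                               # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).half()  # fp16, still on CPU (~12 GB of RAM)

engine = deepspeed.init_inference(model,
                                  mp_size=1,
                                  dtype=torch.float16,
                                  replace_method="auto",
                                  replace_with_kernel_inject=True)
# init_inference moves the (already fp16) weights to the GPU during injection,
# so peak device memory stays close to the fp16 model size.

inputs = tokenizer("DeepSpeed is", return_tensors="pt").to("cuda:0")
print(tokenizer.decode(engine.module.generate(**inputs, max_length=64)[0]))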

@Kkkassini

Same problem: getting OOM using a single 16GB T4 for inference. Is there a sample script for this?

@TiesdeKok

TiesdeKok commented Jan 24, 2022

Has anyone here been able to get the GPT-J inference kernel to work on more than one GPU?

There might be a bug in the code causing it to fail with more than one GPU, see issue: #1719

@RezaYazdaniAminabadi
Contributor

@Kkkassini, can you try creating the pipeline and setting the device after deepspeed.init_inference? Please use this script as an example.
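Not the linked script itself, but the ordering being suggested looks roughly like this (single process per GPU assumed):

import os
import torch
import deepspeed
from transformers import pipeline

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

# Build the pipeline without a device argument so the model stays on CPU,
# let init_inference move it to the GPU, and only then point the pipeline at that device.
generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B")
generator.model = deepspeed.init_inference(generator.model,
                                           mp_size=world_size,
                                           dtype=torch.float16,
                                           replace_method="auto",
                                           replace_with_kernel_inject=True)
generator.device = torch.device(f"cuda:{local_rank}")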

@RezaYazdaniAminabadi
Contributor

@TiesdeKok, there have been some changes to the injection that might have caused this issue. I will look into this and try to fix it soon.
