
Expose GLU activations as arguments #69

Merged 15 commits into main on Aug 22, 2021
Conversation

@jaketae (Member) commented Aug 18, 2021

This PR exposes the improved GLU activation functions (#47) as arguments to the main script. The added flag is:

... --glu-activation geglu

@jaketae requested a review from @thomasw21 on August 18, 2021
@jaketae (Member, Author) commented Aug 18, 2021

@thomasw21 Some additional notes:

  1. I saw that you used enums for the rotary embedding argument, but I decided to go with a raw strings + getattr() combo to load the functions (see the sketch below) instead of writing a lot of if statements. I hope this is fine.
  2. I saw a comment that said JIT does not play well with bf16. By default, all the GLU activation functions are JIT'ed. If this becomes a problem, I'll add more code that exposes them as normal torch functions.
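
A minimal sketch of that raw-string + getattr() pattern, for illustration only; the module path and flag wiring here are assumptions, not the exact code in this PR:

import argparse

from megatron.model import glu_activations  # assumed module housing geglu & co.

parser = argparse.ArgumentParser()
parser.add_argument(
    "--glu-activation",
    type=str,
    default=None,
    help="GLU variant to use, e.g. 'geglu'",
)
args = parser.parse_args()

if args.glu_activation is not None:
    # Look the callable up by name instead of branching with if/elif chains.
    activation_fn = getattr(glu_activations, args.glu_activation)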

Thanks for the review!

@thomasw21 (Member) left a comment

Thanks! Small nit, let me know what you think:

  • Also, could you add a test? Maybe one that just runs a model with a GLU config, so we know it doesn't throw.

About your comments:

  • I dislike getattr, as I generally consider it bad practice, but that really depends on the crowd, I guess. I've seen some occurrences of it in the codebase, so why not?
  • Concerning that ... I don't know, maybe you can test it out?

[Resolved review threads: megatron/model/transformer.py, megatron/arguments.py]
@jaketae (Member, Author) commented Aug 18, 2021

@thomasw21 Thanks for the review!

The activation functions themselves are already tested in tests/test_activations.py, but all of those tests assumed FP32, so I added a basic test that checks BF16, which failed.

Are we running experiments in torch.bfloat16? On my machine, even the built-in torch.nn.functional.gelu() raises this error with a BF16 tensor:

>>> import torch
>>> x = torch.randn(8, 100, 768).to(torch.bfloat16)
>>> x.dtype
torch.bfloat16
>>> torch.nn.functional.gelu(x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/jaketae/Documents/Dev/GitHub/Megatron-DeepSpeed/venv/lib/python3.7/site-packages/torch/nn/functional.py", line 1555, in gelu
    return torch._C._nn.gelu(input)
RuntimeError: "GeluKernelImpl" not implemented for 'BFloat16'
>>> torch.__version__
'1.9.0'

The trace seems to suggest that there's no kernel for F.gelu() that handles BF16 tensors. So I would be very surprised if we ever run anything in BF16, but maybe I'm missing something here.

@thomasw21 (Member) commented Aug 19, 2021

So I'm not too familiar with bf16. From what I understand, it's only available on A100s (CPU support seems to be very basic). As for our current setup, we're running fp16 on all experiments right now.

JZ has very few A100s, and they are more experimental, so there's probably no chance we're running bf16 there. However, that doesn't mean we never will. Let's make a best effort to support it.

Also, concerning tests, maybe we should start discussing the general testing strategy. I usually enjoy broad tests that exercise an entire pipeline: they let us think of the lib as an API and launch scripts via the CLI, so we'd be verifying that expected commands work as expected. But I'm open to unit tests. Let's discuss.
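
To make the pipeline-level idea concrete, here is a hedged sketch of such a smoke test; the entry point, flags, and test name are illustrative assumptions, not the repo's actual test:

import subprocess
import sys
import unittest


class TestGLUTraining(unittest.TestCase):
    def test_training_with_geglu_runs(self):
        # Launch the training entry point via the CLI, exactly as a user would.
        cmd = [
            sys.executable, "pretrain_gpt.py",  # assumed entry point
            "--glu-activation", "geglu",
            "--train-iters", "2",  # keep the smoke run tiny
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        # Passing means the pipeline ran end to end without throwing.
        self.assertEqual(result.returncode, 0, msg=result.stderr)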

@jaketae (Member, Author) commented Aug 21, 2021

@thomasw21 I addressed the points you mentioned (e.g. default, GLU_ACTIVATIONS). As for BF16 and JIT, I think I was getting ahead of myself. Since we're using FP16 on V100s, I don't think JIT will be an issue. If it does become one, we can easily fix it by removing JIT.

I agree that an exhaustive high-level test that goes through the entire pipeline would be helpful. I also think @stas00 might have some ideas in mind. Maybe we can open a separate issue and continue the discussion there.

@stas00 (Contributor) left a comment

Looks good, @jaketae

Probably add --glu-activation=geglu to the main training test. At the moment it just checks that training runs and validates logs where possible, i.e. there's no need for a new training test, since you're already testing the functionality elsewhere.

So here it'd be something like this: 8088c0f

@stas00 (Contributor) commented Aug 22, 2021

Indeed, we don't need to worry about bf16 for JZ, but we could use it on A100s if we get access.

To check for bf16 support, please see:
huggingface/transformers@26eb566

I have an RTX 3090, so if you want me to test something with bf16, please let me know. I only have one card, though.

The jit+bf16 check can be done in a test.
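
For reference, a sketch of what such an availability check might look like, loosely following the referenced transformers commit (the exact conditions there may differ):

import torch
from packaging import version


def is_torch_bf16_available():
    # Requires a CUDA build of torch with a visible GPU.
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    # bf16 tensor cores arrived with Ampere (compute capability >= 8.0)...
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    # ...and need CUDA >= 11.
    if version.parse(torch.version.cuda) < version.parse("11.0"):
        return False
    return True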

@jaketae (Member, Author) commented Aug 22, 2021

@stas00 Thanks for the review.

  1. I added is_torch_bf16_available() from the transformers repo. It is called in test_bf16_jit() to unit-test the GLU activations in an environment where BF16 is available. So if you run python -m unittest tests/test_activations.py on your local RTX 3090, it should properly test whether things work as expected.
  2. I added the activation function flag to test_training_all() as suggested.

Let me know if there's anything that needs to be fixed!

[Resolved review threads: megatron/testing_utils.py, tests/test_activations.py]
@stas00 (Contributor) commented Aug 22, 2021

Here you go, on the RTX 3090:

pytest tests/test_activations.py  -k test_bf16_jit
====================================================================== test session starts ======================================================================
platform linux -- Python 3.8.10, pytest-6.2.4, py-1.10.0, pluggy-0.13.1
rootdir: /mnt/nvme1/code/huggingface, configfile: pytest.ini
plugins: dash-1.20.0, forked-1.3.0, xdist-2.3.0, instafail-0.4.2
collected 6 items / 5 deselected / 1 selected                                                                                                                   

tests/test_activations.py F

_________________________________________________________________ TestActivations.test_bf16_jit _________________________________________________________________

self = <test_activations.TestActivations testMethod=test_bf16_jit>

    @require_torch_bf16
    def test_bf16_jit(self):
        x_bf16 = self.x.to(torch.bfloat16)
        for activation_fn in GLU_ACTIVATIONS.values():
>           output = activation_fn(x_bf16)

tests/test_activations.py:48:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = RecursiveScriptModule(original_name=GEGLU)
input = (tensor([[[ 1.9297,  1.4844,  0.9023,  ...,  1.6797,  1.2812,  1.2969],
         [ 0.6094,  1.3359, -0.2314,  ..., -0....0469,  0.2852],
         [ 0.1680, -0.3379, -1.1484,  ...,  1.2500, -0.6367, -0.5156]]],
       dtype=torch.bfloat16),)
kwargs = {}, forward_call = <torch._C.ScriptMethod object at 0x7f3ed99cd270>

    def _call_impl(self, *input, **kwargs):
        forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
        # If we don't have any hooks, we want to skip the rest of the logic in
        # this function, and just call forward.
        if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
                or _global_forward_hooks or _global_forward_pre_hooks):
>           return forward_call(*input, **kwargs)
E           RuntimeError: The following operation failed in the TorchScript interpreter.
E           Traceback of TorchScript (most recent call last):
E             File "/mnt/nvme1/code/huggingface/Megatron-DeepSpeed-master/megatron/model/glu_activations.py", line 14, in forward
E                   # dim=-1 breaks in jit for pt<1.10
E                   x1, x2 = x.chunk(2, dim=(x.ndim - 1))
E                   return x1 * self.activation_fn(x2)
E                               ~~~~~~~~~~~~~~~~~~ <--- HERE
E             File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/site-packages/torch/nn/functional.py", line 1555, in gelu
E               if has_torch_function_unary(input):
E                   return handle_torch_function(gelu, (input,), input)
E               return torch._C._nn.gelu(input)
E                      ~~~~~~~~~~~~~~~~~ <--- HERE
E           RuntimeError: "GeluKernelImpl" not implemented for 'BFloat16'

/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: RuntimeError
                                                                                                                                                          [100%]
==================================================================== short test summary info ====================================================================
FAILED tests/test_activations.py::TestActivations::test_bf16_jit - RuntimeError: The following operation failed in the TorchScript interpreter.
================================================================ 1 failed, 5 deselected in 2.26s ================================================================

@jaketae (Member, Author) commented Aug 22, 2021

RuntimeError: "GeluKernelImpl" not implemented for 'BFloat16'

I see you got the same error as the one I got locally. The trace seems to suggest that we can't use BF16 with JIT due to some missing kernel implementations. Here are my thoughts:

  1. We could ignore this for now, since we're on V100s anyway (assuming, of course, that test_training_all() passes)?
  2. Once we have access to A100 nodes that can utilize BF16, I'll push non-JIT'ed versions of the activation functions to glu_activations.py so that we can use them if need be (see the sketch below).

What do you think?
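
For reference, a minimal sketch of what the non-JIT'ed variant could look like, reconstructed from the GEGLU forward visible in the traceback above (the class layout is an assumption):

import torch
import torch.nn.functional as F
from torch import nn


class GEGLU(nn.Module):
    def forward(self, x):
        # dim=-1 reportedly breaks under jit for pt<1.10, hence x.ndim - 1.
        x1, x2 = x.chunk(2, dim=x.ndim - 1)
        return x1 * F.gelu(x2)


geglu_jit = torch.jit.script(GEGLU())  # the PR's current default
geglu = GEGLU()  # plain eager fallback; bf16 still fails until torch ships a bf16 gelu kernel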

@stas00 (Contributor) commented Aug 22, 2021

This is unrelated to JIT; it's the same error if you remove jit:

-geglu = torch.jit.script(GEGLU())
+geglu = GEGLU()
E       RuntimeError: "GeluKernelImpl" not implemented for 'BFloat16'

It says it has no such kernel implemented for this dtype.

@stas00 (Contributor) commented Aug 22, 2021

We definitely don't have to do anything about it for now. We will cross that bridge when we come to it.

Until then, probably comment out the whole bf16 test, since we might want it down the road once torch implements bf16 support for those kernels. It's also good to have it there as an example for other bf16 tests; one way to keep it around is sketched below.
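
An alternative to commenting it out (just a sketch; the PR itself comments the test out) is to park the test behind a skip decorator, mirroring the test body from the log above:

import unittest

import torch

from megatron.model.glu_activations import GLU_ACTIVATIONS


class TestActivations(unittest.TestCase):
    def setUp(self):
        # Same shape as the repro earlier in the thread.
        self.x = torch.randn(8, 100, 768)

    @unittest.skip("torch has no bf16 gelu kernel yet; re-enable when it lands")
    def test_bf16_jit(self):
        x_bf16 = self.x.to(torch.bfloat16)
        for activation_fn in GLU_ACTIVATIONS.values():
            activation_fn(x_bf16)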

@jaketae (Member, Author) commented Aug 22, 2021

Got it, thanks for the clarification. So it's just a BF16 kernel issue, not JIT. I'll comment out the test for now. Thanks!

jaketae and others added 2 commits on August 23, 2021:
  • uncomment in the future when torch supports gelu kernels for bf16
@stas00 (Contributor) left a comment

This is good to go now.

[Resolved review thread: megatron/arguments.py]
@jaketae merged commit b5a029d into main on Aug 22, 2021
@jaketae deleted the activation-args branch on August 22, 2021
@stas00 mentioned this pull request on Aug 22, 2021
@jaketae (Member, Author) commented Aug 22, 2021

@stas00 I shouldn't have merged this so quickly; thanks for the quick fix!

@stas00 (Contributor) commented Aug 22, 2021

All is good. It was my unvalidated suggestion that broke it.

Once we get CI set up, it'll be much easier.

adammoody pushed a commit to adammoody/Megatron-DeepSpeed that referenced this pull request on Oct 27, 2022:
* add checkpoint measurement

* Update CODEOWNERS

* add TFLOP per sec support

* remove duplicate tflops calculation

* remove unnecessary comment

* remove comments

* remove comment