Alibi Tensor Parallel Fix #244

Merged
merged 6 commits on Feb 1, 2022

Changes from 3 commits
6 changes: 6 additions & 0 deletions megatron/model/transformer.py
@@ -608,6 +608,12 @@ def get_slopes_power_of_2(n):
        slopes = torch.Tensor(get_slopes(num_attention_heads))
        alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand(
            num_attention_heads, -1, -1)

        # Select the part of the tensor that corresponds to our tensor parallel index.
        tp_world_size = mpu.get_tensor_model_parallel_world_size()
        tp_index = mpu.get_tensor_model_parallel_rank()
        alibi = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index]
Member
Suggested change
alibi = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index]
num_attention_head_per_partition = mpu.divide(num_attention_heads, tp_world_size)
alibi = alibi[tp_index * num_attention_head_per_partition: (tp_index + 1) * num_attention_head_per_partition]

Personally I always find reshape to be weird magic.
We can probably do something more efficient by only computing what we need, but let's do that for now since this is just done at init and should be short.

Collaborator Author
Disagree, I think the reshape is clearer, I'll keep it as is.
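
For reference, a minimal standalone sketch (not part of the PR, using plain integer slicing instead of the mpu helpers) showing that the reshape-then-index version and the explicit-slice version select the same heads:

import torch

num_attention_heads, max_seq_len, tp_world_size = 8, 4, 2
slopes = torch.rand(num_attention_heads)
alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand(
    num_attention_heads, -1, -1)

for tp_index in range(tp_world_size):
    # Variant from this PR: fold the head dim into (tp_world_size, heads_per_partition, ...) and index.
    via_reshape = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index]
    # Variant from the suggested change: slice the contiguous block of heads owned by this rank.
    heads_per_partition = num_attention_heads // tp_world_size
    via_slice = alibi[tp_index * heads_per_partition: (tp_index + 1) * heads_per_partition]
    assert torch.equal(via_reshape, via_slice)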


        alibi = alibi.repeat(batch_size, 1, 1)
        return alibi

219 changes: 219 additions & 0 deletions tests/test_tensor_paralell.py
@@ -0,0 +1,219 @@
from gc import get_referents
Member
There are some unused imports

import sys, os
dir = os.path.abspath(os.path.join(os.path.dirname(__file__),os.path.pardir))
sys.path.append(dir)
Member
I'm so confused, our tests don't have that. I'm guessing you haven't installed the repo via pip. Please remove it.

Contributor
What should be installed via pip? Megatron-LM and its derivatives aren't installable.

Contributor
@stas00 Feb 1, 2022
but we are already adding the root dir automatically here for all tests to enjoy.

git_repo_path = abspath(join(dirname(dirname(__file__))))
sys.path.insert(1, git_repo_path)

so it's probably just redundant, and that's why Thomas suggested removing it.

Contributor
probably a more modern readable version would be:

import sys
from pathlib import Path
git_repo_path = Path(__file__).resolve().parents[1]
sys.path.insert(1, str(git_repo_path))

I wasn't sure if it's one parent up though or 2 you wanted.

but it's fine as it is as well.

Collaborator Author
Oh right, I'll just remove it altogether. It was just for quicker iteration without going through pytest.

Contributor
in such cases it's easier to use:

PYTHONPATH=`pwd` tests/test.py

or something like that. :)


import unittest
from random import randint
from unittest.mock import patch

import deepspeed
import torch
import logging

import pytest
from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
from megatron.testing_utils import TestCasePlus, mockenv_context
from megatron.training import setup_model_and_optimizer
from megatron.mpu.mappings import gather_from_tensor_model_parallel_region
from pretrain_gpt import model_provider as gpt_model_provider, get_batch_pipe as get_gpt_batch_pipe
from pretrain_prefix_lm import model_provider as prefix_lm_model_provider, get_batch_pipe as get_prefix_lm_batch_pipe
import multiprocessing as mp
from multiprocessing import Pool
from megatron.checkpointing import save_checkpoint

from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model

def flatten_arguments(args):
    """
    Converts dictionary argument to a list.

    Note: we add "IGNORED" at the beginning as this value is ignored by the argparser

    Example: {"arg1": "value1", "arg2": "value2"} -> ["IGNORED", "arg1", "value1", "arg2", "value2"]
    """
    return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]
Member
That's duplicated code. Ideally if you think it's helpful you can add it to testing_utils and import it directly.

Collaborator Author
Done
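
For reference, a quick sketch (not part of the diff) of what the helper produces; note that flag-style arguments with empty values collapse to just the flag name:

flatten_arguments({"--num-layers": "2", "--fp16": ""})
# -> ["IGNORED", "--num-layers", "2", "--fp16"]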



def equal_vectors(tensor1, tensor2, dim=-1):
Member
ditto

"""View tensor1 and tensor2 as a list of vectors, and compute equality"""
return torch.linalg.norm(tensor1 - tensor2, dim=dim) == 0
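
A quick illustration of the return value (hypothetical, not in the diff): for inputs of shape (batch, dim) it yields one boolean per vector.

a = torch.zeros(2, 4)
b = torch.zeros(2, 4)
b[1, 0] = 1.0
equal_vectors(a, b)  # tensor([ True, False])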


class MyTestCase(TestCasePlus):
    def get_default_args(self, tp_size):
        """return a dictionary with key as argument name and value as additional arguments"""
        return {
            # GPT_ARGS
            "--num-layers": "2",
            "--hidden-size": "128",
            "--num-attention-heads": "4",
            "--seq-length": "256",
            "--max-position-embeddings": "256",
            "--micro-batch-size": "4",
            "--global-batch-size": "8",
            "--lr-decay-iters": "320000",
            "--lr-decay-style": "cosine",
            "--lr": "0.00015",
            "--min-lr": "1.0e-5",
            "--train-iters": "5000",
            "--tokenizer-type": "PretrainedFromHF",
            "--tokenizer-name-or-path": "gpt2",
            "--data-impl": "mmap",
            "--split": "949,50,1",
            "--distributed-backend": "nccl",
            "--weight-decay": "1e-2",
            "--clip-grad": "1.0",
            "--lr-warmup-fraction": ".01",
            "--fp16": "",

            "--attention-dropout": "0",
            "--hidden-dropout": "0",

            # ALIBI:
            "--position-embedding-type": "alibi",
Member
Maybe not put this as a default? The way I see it, the defaults should basically be the common config people would use. If you strongly disagree let me know.

Collaborator Author
Good call, fixed


            # OUTPUT_ARGS
            "--log-interval": "10",
            "--save-interval": "500",
            "--eval-interval": "100",
            "--eval-iters": "10",
            "--checkpoint-activations": "",

            # parallel args
            "--tensor-model-parallel-size": str(tp_size),
Member
ditto, you can

args = get_default_args()
args["--tensor-model-parallel-size"] = str(tp_size)


            # ds args
            "--deepspeed": "",
            "--deepspeed_config": f"{self.test_file_dir_str}/ds_config.json",
            "--zero-stage": "1",
            "--deepspeed-activation-checkpointing": ""
            # DATA_ARGS
        }

    def setUp(self) -> None:
        super().setUp()

        # We reset all global variables
        global_vars._GLOBAL_ARGS = None
        global_vars._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
        global_vars._GLOBAL_TOKENIZER = None
        global_vars._GLOBAL_TENSORBOARD_WRITER = None
        global_vars._GLOBAL_ADLR_AUTORESUME = None
        global_vars._GLOBAL_TIMERS = None


    def infer_model(args):
        tp_index, tp_size, command_args, token_ids, save, load = args
        dist_env_1_gpu = dict(
            MASTER_ADDR="localhost", MASTER_PORT="9994", RANK=str(tp_index), LOCAL_RANK=str(tp_index), WORLD_SIZE=str(tp_size)
        )
        logging.getLogger().critical("Process: starting")

        # Hack
        import megatron.initialize as init
        init.git_ds_info = lambda: None

        with patch('sys.argv', flatten_arguments(command_args)):
            with mockenv_context(**dist_env_1_gpu):

                def create_model_inputs(tokens):
                    args = get_args()

                    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
                        tokens,
                        tokenizer.eod,
                        args.reset_position_ids,
                        args.reset_attention_mask,
                        args.eod_mask_loss,
                        prefix_indices=None,
                        loss_on_targets_only=False)

                    return (tokens, position_ids, attention_mask), (tokens, loss_mask)

                deepspeed.init_distributed()
                initialize_megatron()
                args = get_args()

                args.vocab_size = args.padded_vocab_size = 1024

                tokenizer = get_tokenizer()

                model, _, _ = setup_model_and_optimizer(gpt_model_provider)
                model = model[0]
                if load is not None:
                    # Hack (same as in eval_harness/evaluate.py)
                    # Loading pipelined models in deepspeed with a different TP than they were originally trained on fails
                    # due to a sanity check that makes sure all state_dicts that we merge contain attention layers.
                    # This, however, is not true for pipelining when we merge the state_dict for the embeddings,
                    # which does not contain these attention-specific keys.
                    #
                    # Deepspeed does however manage to load the model if we just turn off this sanity check.
                    deepspeed.runtime.state_dict_factory.MegatronSDLoader.sanity_check = lambda self, ckpt_file_name: None

                    zero_enabled = model._config.zero_enabled
                    model._config.zero_enabled = False
                    _, _ = model.load_checkpoint(load, load_optimizer_states=False, load_lr_scheduler_states=False, load_module_only=True)
                    model._config.zero_enabled = zero_enabled

                if token_ids is None:
                    token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))

                    # eod is a special token: remap any sampled eod ids so the random input doesn't contain it
                    token_ids[token_ids == tokenizer.eod] += 1
                    token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size
                else:
                    token_ids = torch.tensor(token_ids)

                model.micro_batches = 1
                model.set_batch_fn(create_model_inputs)
                # process batch
                input_batch = get_gpt_batch_pipe({"text": token_ids})[0]

                # get a modified version of the first batch, where we change a specific index
                changed_index = randint(0, args.seq_length - 2)
                input_token_ids_changed = input_batch[0].clone()
                # We increment the token_id by one for that index in order to artificially change the sequence.
                input_token_ids_changed[:, changed_index] = \
                    (input_token_ids_changed[:, changed_index] + 1) % args.padded_vocab_size

                # output = model(*input_batch)
                output = model.eval_batch(iter([token_ids]), compute_loss=False, reduce_output=None)[0]

                output = gather_from_tensor_model_parallel_region(output)[..., :tokenizer.vocab_size]

                if save is not None:
                    args.save = save
                    save_checkpoint(0, [model], None, None)

                return (output[0].detach().cpu().numpy(), token_ids.detach().cpu().numpy())

    def test_cross(self):

        mp.set_start_method('spawn', force=True)
        cp_dir = self.get_auto_remove_tmp_dir()

        command_args = self.get_default_args(tp_size=1)
        pool = Pool(1)
        result = pool.map(MyTestCase.infer_model, [(0, 1, command_args, None, cp_dir, None)])
        pool.close()
        pool.join()

        output, tokens = result[0]
        logging.getLogger().critical("First done!")

        command_args = self.get_default_args(tp_size=2)
        pool = Pool(2)
        result = pool.map(MyTestCase.infer_model, [(0, 2, command_args, tokens, None, cp_dir), (1, 2, command_args, tokens, None, cp_dir)])
        pool.close()
        pool.join()

        output2, tokens = result[0]

        logging.getLogger().critical(output - output2)
        import numpy as np
        self.assertTrue(np.allclose(output, output2, atol=5e-3, rtol=0), "Different results when running with TP=1 and TP=2")

if __name__ == '__main__':
    unittest.main()