Re-enable GPT-J unit tests and refactor inference tests #3618

Merged Jun 28, 2023 · 36 commits merged into master from mrwyattii/fix-broken-gptj-tests

Changes from 34 commits

Commits (36)
148af32
gpt-j model was renamed on HF
mrwyattii May 26, 2023
51fa3e0
Merge branch 'master' into mrwyattii/fix-broken-gptj-tests
loadams Jun 7, 2023
7d3bcbf
Merge branch 'master' into mrwyattii/fix-broken-gptj-tests
loadams Jun 12, 2023
e5a755c
Bad merge in GitHub GUI
loadams Jun 12, 2023
df1859d
zero++ tutorial PR (#3783)
HeyangQin Jun 21, 2023
d81a6ad
[Fix] _conv_flops_compute when padding is a str and stride=1 (#3169)
zhiruiluo Jun 21, 2023
a8c182a
fix interpolate flops compute (#3782)
cli99 Jun 22, 2023
c4c442f
use `Flops Profiler` to test `model.generate()` (#2515)
CaffreyR Jun 22, 2023
7c6e3ab
Merge branch 'master' into mrwyattii/fix-broken-gptj-tests
mrwyattii Jun 22, 2023
fc9e1ee
revert PR #3611 (#3786)
jeffra Jun 22, 2023
40045dc
bump to 0.9.6
jeffra Jun 22, 2023
49a0a1b
ZeRO++ chinese blog (#3793)
HeyangQin Jun 23, 2023
2c62cb4
remove staging trigger (#3792)
jeffra Jun 23, 2023
4dc65f7
DeepSpeed-Triton for Inference (#3748)
stephen-youn Jun 23, 2023
e1119d8
ZeRO++ (#3784)
HeyangQin Jun 23, 2023
01b843a
adding zero++ to navigation panel of deepspeed.ai (#3796)
HeyangQin Jun 23, 2023
319b64e
Add ZeRO++ Japanese blog (#3797)
tohtana Jun 23, 2023
b4a2c0a
Bug Fixes for autotuner and flops profiler (#1880)
cli99 Jun 23, 2023
b7e1010
Missing strided copy for gated MLP (#3788)
cmikeh2 Jun 23, 2023
e5b1ead
Requires grad checking. (#3789)
jomayeri Jun 23, 2023
9c756cf
bump to 0.10.0
jeffra Jun 23, 2023
babd883
update how we generate inference model/task combinations to reduce nu…
mrwyattii Jun 23, 2023
8163d8c
consolidate fixtures so they can be reused
mrwyattii Jun 23, 2023
8011778
Merge branch 'master' into mrwyattii/fix-broken-gptj-tests
mrwyattii Jun 23, 2023
ba44a08
resolve changes from master merge
mrwyattii Jun 23, 2023
a204edc
Fix Bug in transform.cu (#3534)
rraminen Jun 23, 2023
bba09ac
Merge branch 'master' into mrwyattii/fix-broken-gptj-tests
loadams Jun 23, 2023
281e150
Merge branch 'master' into mrwyattii/fix-broken-gptj-tests
mrwyattii Jun 26, 2023
81af369
Update test_inference.py
mrwyattii Jun 26, 2023
246d41c
formatting
mrwyattii Jun 26, 2023
939cd97
Merge branch 'master' into mrwyattii/fix-broken-gptj-tests
mrwyattii Jun 27, 2023
acde8ca
fix injection policy test
mrwyattii Jun 27, 2023
c7a4cc8
revert moving fixtures to separate file
mrwyattii Jun 27, 2023
a8db16b
remove init
mrwyattii Jun 27, 2023
bf2e650
change profiling model to one that supports CUDA Graph
mrwyattii Jun 28, 2023
56d970b
Merge branch 'master' into mrwyattii/fix-broken-gptj-tests
mrwyattii Jun 28, 2023
5 changes: 5 additions & 0 deletions .flake8
@@ -0,0 +1,5 @@
[flake8]
ignore = E,F403,F405,F541,F841,W
select = E9,F,W6
per-file-ignores =
__init__.py:F401
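The new .flake8 file centralizes the lint settings that were previously passed as command-line arguments in .pre-commit-config.yaml (next file). As an illustration of what select=E9,F,W6 combined with ignore=E,F403,F405,F541,F841,W means in practice, the hypothetical snippet below (not part of the PR) is flagged for the logic error but not for style:

# sketch.py - hypothetical file, not in the repository.
# flake8 reports F821 (undefined name) because the "F" family is selected and
# F821 is not in the ignore list; pure style codes (E1xx-E7xx, W1-W5) stay silent.
def greet():
    return f"hello, {name}"  # F821: undefined name 'name'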
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -67,7 +67,7 @@ repos:
rev: 4.0.1
hooks:
- id: flake8
args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401']
args: ['--config=.flake8']

- repo: local
hooks:
1 change: 1 addition & 0 deletions tests/pytest.ini
@@ -6,3 +6,4 @@ markers =
inference_ops:Individual inference operator tests
seq_inference:Inference model tests to run sequentially
nightly:Tests that should be run nightly
world_size:Change world size of individual tests in a class
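The new world_size marker lets an individual test override the world size declared on its DistributedTest class. A minimal sketch of the intended usage follows; the marker-resolution logic here is an assumption made for illustration, not DeepSpeed's actual DistributedTest implementation:

# test_world_size_sketch.py - illustrative only; resolution logic is assumed.
import pytest

class TestWorldSizeSketch:
    world_size = 1  # class-wide default, as on DistributedTest subclasses

    def _resolve_world_size(self, request):
        # A per-test @pytest.mark.world_size(N) overrides the class attribute.
        marker = request.node.get_closest_marker("world_size")
        return marker.args[0] if marker else self.world_size

    def test_default(self, request):
        assert self._resolve_world_size(request) == 1

    @pytest.mark.world_size(2)  # register the marker in pytest.ini to avoid a warning
    def test_two_ranks(self, request):
        assert self._resolve_world_size(request) == 2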
165 changes: 84 additions & 81 deletions tests/unit/inference/test_inference.py
@@ -49,75 +49,57 @@
"gpt2",
"distilgpt2",
"Norod78/hebrew-bad_wiki-gpt_neo-tiny",
"EleutherAI/gpt-j-6B", # bring back this model as we did not catch an error before by merging some changes! TODO: we need to fix the OOM issue later!
"EleutherAI/gpt-j-6b",
"EleutherAI/pythia-70m-deduped",
"bigscience/bloom-560m",
]
_opt_models = [
"facebook/opt-125m", # 125m, 1.7B, ..., 175B variants have the same model architecture.
"facebook/opt-350m", # 350m applies layer norm after attention layer which is different than other variants.
]
_all_models = HfApi().list_models()

test_models = set(_bert_models + _roberta_models + _gpt_models + _opt_models)
test_tasks = [
_test_models = set(_bert_models + _roberta_models + _gpt_models + _opt_models)
_test_tasks = [
"fill-mask", "question-answering", "text-classification", "token-classification", "text-generation",
"text2text-generation", "summarization", "translation"
]
pytest.all_models = {task: [m.modelId for m in _all_models if m.pipeline_tag == task] for task in test_tasks}

_model_w_tasks = itertools.product(*[test_models, test_tasks])


def _valid_model_task(model_task):
m, t = model_task
return m in pytest.all_models[t]


pytest.models_w_tasks = list(filter(_valid_model_task, _model_w_tasks))
pytest.mt_names = [f"{m}-{t}" for m, t in pytest.models_w_tasks]
"""
These fixtures iterate all combinations of tasks and models, dtype, & cuda_graph
"""


@pytest.fixture(params=pytest.models_w_tasks, ids=pytest.mt_names)
def model_w_task(request):
return request.param


@pytest.fixture(params=[torch.float, torch.half], ids=["fp32", "fp16"])
def dtype(request):
return request.param

# Get a list of all models and mapping from task to supported models
_hf_models = HfApi().list_models()
_hf_model_names = [m.modelId for m in _hf_models]
_hf_task_to_models = {task: [m.modelId for m in _hf_models if m.pipeline_tag == task] for task in _test_tasks}

@pytest.fixture(params=[True, False], ids=["CG", "noCG"])
def enable_cuda_graph(request):
return request.param
# Get all combinations of task:model to test
_model_w_tasks = [(m, t) for m, t in itertools.product(*[_test_models, _test_tasks]) if m in _hf_task_to_models[t]]

# Assign to pytest variables for testing
pytest.model_w_tasks = _model_w_tasks
pytest.mt_names = [f"{m}-{t}" for m, t in pytest.model_w_tasks]

@pytest.fixture(params=[True, False], ids=["Triton", "noTriton"])
def enable_triton(request):
return request.param

@pytest.fixture(scope="module", autouse=True)
def verify_models():
# Verify all test models are registered in HF
_test_models_not_found = [m for m in _test_models if m not in _hf_model_names]
if _test_models_not_found:
pytest.fail(f"Model(s) not found in HuggingFace: {_test_models_not_found}")

"""
This fixture will validate the configuration
"""
# Verify all models are assigned to at least one task
_models_to_be_tested = set(m for m, t in _model_w_tasks)
_missing_task_models = _models_to_be_tested.difference(_test_models)
if _missing_task_models:
pytest.fail(f"Model(s) do not have an assigned task: {_missing_task_models}")


# Fixture to add skips for certain configurations
@pytest.fixture()
def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph, enable_triton):
def invalid_test(model_w_task, dtype, enable_cuda_graph, enable_triton):
model, task = model_w_task
msg = ""
if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
msg = "DS inference injection doesn't work well on older torch versions"
elif model not in pytest.all_models[task]:
msg = f"Not a valid model / task combination: {model} / {task}"
elif enable_cuda_graph and (torch_info["cuda_version"] == "0.0"):
if enable_cuda_graph and (torch_info["cuda_version"] == "0.0"):
msg = "CUDA not detected, cannot use CUDA Graph"
elif enable_cuda_graph and pkg_version.parse(torch.__version__) < pkg_version.parse("1.10"):
msg = "CUDA Graph is only available in torch versions >= 1.10"
elif "gpt-j-6B" in model:
elif "gpt-j-6b" in model:
if dtype != torch.half:
msg = f"Not enough GPU memory to run {model} with dtype {dtype}"
elif enable_cuda_graph:
@@ -139,10 +121,30 @@ def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph, enable_tri
return msg


"""
These fixtures can be used to customize the query, inference args, and assert
statement for each combination of model /task
"""
""" Fixtures for inference config """


@pytest.fixture(params=pytest.model_w_tasks, ids=pytest.mt_names)
def model_w_task(request):
return request.param


@pytest.fixture(params=[torch.float, torch.half], ids=["fp32", "fp16"])
def dtype(request):
return request.param


@pytest.fixture(params=[True, False], ids=["CG", "noCG"])
def enable_cuda_graph(request):
return request.param


@pytest.fixture(params=[True, False], ids=["Triton", "noTriton"])
def enable_triton(request):
return request.param


""" Fixtures for running query """


@pytest.fixture
@@ -178,14 +180,17 @@ def query(model_w_task):
def inf_kwargs(model_w_task):
model, task = model_w_task
if task == "text-generation":
if model == "EleutherAI/gpt-j-6B":
if model == "EleutherAI/gpt-j-6b":
# This model on V100 is hitting memory problems that limit the number of output tokens
return {"do_sample": False, "max_length": 12}
return {"do_sample": False, "max_length": 20}
else:
return {}


""" Assertion fixture for verifying model outputs """


def fill_mask_assert(x, y):
return set(res["token_str"] for res in x) == set(res["token_str"] for res in y)

@@ -237,6 +242,7 @@ def assert_fn(model_w_task):
return assert_fn


# Used to verify DeepSpeed kernel injection worked with a model
def check_injection(model):

def verify_injection(module):
@@ -251,27 +257,24 @@ def verify_injection(module):
verify_injection(model)


"""
Tests
"""


@pytest.mark.inference
class TestModelTask(DistributedTest):
world_size = 1

def test(self,
model_w_task,
dtype,
enable_cuda_graph,
enable_triton,
query,
inf_kwargs,
assert_fn,
invalid_model_task_config,
perf_meas=True):
if invalid_model_task_config:
pytest.skip(invalid_model_task_config)
def test(
self,
model_w_task,
dtype,
enable_cuda_graph,
enable_triton,
query,
inf_kwargs,
assert_fn,
invalid_test,
perf_meas=True,
):
if invalid_test:
pytest.skip(invalid_test)

model, task = model_w_task
local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -338,10 +341,10 @@ def test(self,
@pytest.mark.parametrize("model_w_task", [("EleutherAI/gpt-neo-1.3B", "text-generation"),
("EleutherAI/gpt-neox-20b", "text-generation"),
("bigscience/bloom-3b", "text-generation"),
("EleutherAI/gpt-j-6B", "text-generation")],
("EleutherAI/gpt-j-6b", "text-generation")],
ids=["gpt-neo", "gpt-neox", "bloom", "gpt-j"])
class TestMPSize(DistributedTest):
world_size = 4
world_size = 2

def test(
self,
@@ -350,10 +353,10 @@ def test(
query,
inf_kwargs,
assert_fn,
invalid_model_task_config,
invalid_test,
):
if invalid_model_task_config:
pytest.skip(invalid_model_task_config)
if invalid_test:
pytest.skip(invalid_test)

model, task = model_w_task
local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -402,12 +405,12 @@ def test(
query,
inf_kwargs,
assert_fn,
invalid_model_task_config,
invalid_test,
dtype,
enable_cuda_graph,
):
if invalid_model_task_config:
pytest.skip(invalid_model_task_config)
if invalid_test:
pytest.skip(invalid_test)

model, task = model_w_task
local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -452,12 +455,12 @@ def test(
query,
inf_kwargs,
assert_fn,
invalid_model_task_config,
invalid_test,
dtype,
enable_cuda_graph,
):
if invalid_model_task_config:
pytest.skip(invalid_model_task_config)
if invalid_test:
pytest.skip(invalid_test)

model, task = model_w_task
local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -483,7 +486,7 @@ def test(
"model_family, model_name",
(
["gpt2", "EleutherAI/gpt-neo-2.7B"],
["gpt2", "EleutherAI/gpt-j-6B"],
["gpt2", "EleutherAI/gpt-j-6b"],
["gpt2", "gpt2-xl"],
),
)
@@ -503,7 +506,7 @@ def test(self, model_family, model_name, task):
dtype = torch.float
task_dict = lm_eval.tasks.get_task_dict([task])

if 'gpt-j-6B' in model_name:
if 'gpt-j-6b' in model_name:
dtype = torch.half
lm = lm_eval.models.get_model(model_family).create_from_arg_string(f"pretrained={model_name}",
{"device": "cpu"})
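The main structural change above is that valid (model, task) combinations are now computed once at import time (pytest.model_w_tasks / pytest.mt_names) instead of being filtered inside the skip fixture. The standalone sketch below mirrors that pairing logic; the task-to-model map is hard-coded (an invented stand-in for the HfApi().list_models() lookup) so the example runs offline:

# pairing_sketch.py - standalone illustration of the new selection logic.
import itertools

_test_models = {"bert-base-cased", "gpt2", "facebook/opt-125m"}
_test_tasks = ["fill-mask", "text-generation"]

# In the real test this mapping is built from each HF model's pipeline_tag;
# a small fake mapping keeps the sketch fast and offline.
_hf_task_to_models = {
    "fill-mask": ["bert-base-cased", "roberta-base"],
    "text-generation": ["gpt2", "facebook/opt-125m", "distilgpt2"],
}

# Keep only (model, task) pairs where the model is registered for that task.
_model_w_tasks = [(m, t) for m, t in itertools.product(_test_models, _test_tasks)
                  if m in _hf_task_to_models.get(t, [])]
mt_names = [f"{m}-{t}" for m, t in _model_w_tasks]

if __name__ == "__main__":
    print(sorted(mt_names))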
41 changes: 8 additions & 33 deletions tests/unit/inference/test_model_profiling.py
@@ -13,43 +13,18 @@
from deepspeed.accelerator import get_accelerator


@pytest.fixture
def query(model, task):
if task == "text-generation":
return "DeepSpeed is"
elif task == "fill-mask":
if "roberta" in model:
return "I am a <mask> model"
else:
return "I am a [MASK] model"
else:
raise NotImplementedError


@pytest.fixture
def inf_kwargs(task):
if task == "text-generation":
return {"do_sample": False, "min_length": 50, "max_length": 50}
else:
return {}


@pytest.mark.inference
@pytest.mark.parametrize("model,task", [
("bert-base-cased", "fill-mask"),
("roberta-base", "fill-mask"),
("gpt2", "text-generation"),
("facebook/opt-125m", "text-generation"),
("bigscience/bloom-560m", "text-generation"),
])
@pytest.mark.parametrize("cuda_graphs", [True, False])
@pytest.mark.parametrize("use_cuda_events", [True, False])
@pytest.mark.parametrize("enable_cuda_graph", [True, False])
class TestModelProfiling(DistributedTest):
world_size = 1

def test(self, model, task, query, inf_kwargs, cuda_graphs, use_cuda_events, dtype=torch.float16):
if cuda_graphs and "bert" not in model:
pytest.skip(f"CUDA Graph not supported for {model}")
def test(self, enable_cuda_graph, use_cuda_events):
task = "text-generation"
model = "bigscience/bloom-560m"
dtype = torch.float16
query = "DeepSpeed is"
inf_kwargs = {"do_sample": False, "min_length": 50, "max_length": 50}

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -59,7 +34,7 @@ def test(self, model, task, query, inf_kwargs, cuda_graphs, use_cuda_events, dty
dtype=dtype,
mp_size=world_size,
replace_with_kernel_inject=True,
enable_cuda_graph=cuda_graphs)
enable_cuda_graph=enable_cuda_graph)
pipe.model.profile_model_time(use_cuda_events=use_cuda_events)

e2e_times = []
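The simplified profiling test keeps a single model/task pair and parametrizes only enable_cuda_graph and use_cuda_events. The use_cuda_events flag chooses between CUDA-event timing and host-side wall-clock timing; the sketch below shows the two modes in isolation (an illustration, not DeepSpeed's profiler code):

# timing_sketch.py - illustrative comparison of the two timing modes.
import time
import torch

def time_call(fn, use_cuda_events: bool) -> float:
    """Return elapsed milliseconds for fn(), timed one of two ways."""
    if use_cuda_events and torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()  # event timings are valid only after a sync
        return start.elapsed_time(end)
    t0 = time.time()
    fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure queued kernels have finished
    return (time.time() - t0) * 1000.0

if __name__ == "__main__":
    x = torch.randn(512, 512)
    print(time_call(lambda: x @ x, use_cuda_events=torch.cuda.is_available()))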