Allow args to be optional in deepspeed.initialize #825

Merged (7 commits, Mar 16, 2021)
12 changes: 8 additions & 4 deletions deepspeed/__init__.py
@@ -49,8 +49,8 @@ def _parse_version(version_str):
 sys.modules['deepspeed.pt.loss_scaler'] = deepspeed.runtime.fp16.loss_scaler


-def initialize(args,
-               model,
+def initialize(args=None,
+               model=None,
                optimizer=None,
                model_parameters=None,
                training_data=None,
@@ -62,8 +62,7 @@ def initialize(args,
     """Initialize the DeepSpeed Engine.

     Arguments:
-        args: a dictionary containing local_rank and deepspeed_config
-            file location
+        args: an object containing local_rank and deepspeed_config fields. This is optional if `config_params` is passed.

         model: Required: nn.module class before apply any wrappers
@@ -88,6 +87,9 @@ def initialize(args,
             mini-batch of Tensor(s). Used when using batched loading from a
             map-style dataset.

+        config_params: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed config
+            as a dictionary instead.
+
     Returns:
         A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler``
@@ -108,6 +110,8 @@ def initialize(args,
                                       __git_branch__),
              ranks=[0])

+    assert model is not None, "deepspeed.initialize requires a model"
+
     if not isinstance(model, PipelineModule):
         engine = DeepSpeedEngine(args=args,
                                  model=model,
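With these changes the engine can be constructed from a config dictionary alone, with no argparse namespace and no --deepspeed_config file. A minimal usage sketch, mirroring the tests added in this PR (the toy model below is an illustrative placeholder, not part of DeepSpeed):

import torch
import deepspeed

# Any torch.nn.Module works; this toy model is only for illustration.
model = torch.nn.Linear(10, 10)

ds_config = {
    "train_batch_size": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},
}

# No `args` needed: the DeepSpeed config is supplied directly as a dict
# through `config_params`.
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config_params=ds_config)

Note that the process must still be launched in a way that sets LOCAL_RANK (e.g. the deepspeed launcher or deepspeed.init_distributed), since the reworked sanity check in engine.py asserts on that environment variable.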
20 changes: 10 additions & 10 deletions deepspeed/runtime/engine.py
@@ -495,9 +495,10 @@ def _configure_with_arguments(self, args, mpu):
         # After the distributed backend is initialized we are guaranteed the LOCAL_RANK
         # environment variable is set. We must align args.local_rank to this value for
         # backwards compatability with scripts relying on [args|self].local_rank containing
-        # the correct local rank info.
-        args.local_rank = int(os.environ['LOCAL_RANK'])
-        self.local_rank = args.local_rank
+        # the correct local rank info. _do_args_sanity_check will ensure this is the case.
+        self.local_rank = int(os.environ['LOCAL_RANK'])
+        if hasattr(args, 'local_rank'):
+            args.local_rank = self.local_rank

         config_file = args.deepspeed_config if hasattr(args,
                                                        'deepspeed_config') else None
@@ -513,15 +514,14 @@ def _do_args_sanity_check(self, args):
             assert args.deepspeed_config is None, "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
             args.deepspeed_config = args.deepscale_config

-        local_rank_err = "DeepSpeed requires a command line parameter of --local_rank [int] and/or setting the LOCAL_RANK environment variable."
-        if hasattr(args, 'local_rank'):
-            assert type(args.local_rank) == int, local_rank_err
-            if "LOCAL_RANK" in os.environ and args.local_rank >= 0:
-                env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+        assert "LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment variable, it is set by the deepspeed launcher, " \
+            "deepspeed.init_distributed, or the torch.distributed launcher. If using a different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed."
+        if hasattr(args, 'local_rank') and args.local_rank != None:
+            assert isinstance(args.local_rank, int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}"
+            if args.local_rank >= 0:
+                env_local_rank = int(os.environ.get("LOCAL_RANK"))
                 assert env_local_rank == args.local_rank, \
                     f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}."
-        else:
-            assert "LOCAL_RANK" in os.environ, local_rank_err

         if self.config_params is None:
             assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \
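The reworked sanity check makes the LOCAL_RANK environment variable the single source of truth and only cross-checks args.local_rank when the caller actually supplies one. A standalone sketch of that behavior (the helper name below is hypothetical, not part of DeepSpeed's API):

import os

def check_local_rank(args):
    # LOCAL_RANK must always be set, however the process was launched.
    assert "LOCAL_RANK" in os.environ, \
        "set LOCAL_RANK via the deepspeed launcher, deepspeed.init_distributed, or torch.distributed"
    local_rank = getattr(args, "local_rank", None)
    if local_rank is not None and local_rank >= 0:
        # A caller-supplied local_rank must agree with the environment.
        assert local_rank == int(os.environ["LOCAL_RANK"]), \
            f"args.local_rank={local_rank} disagrees with LOCAL_RANK={os.environ['LOCAL_RANK']}"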
80 changes: 80 additions & 0 deletions tests/unit/test_config.py
@@ -226,3 +226,83 @@ def _helper():
         model.step()

     _helper()
+
+
+def test_none_args(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "fp16": {
+            "enabled": True
+        }
+    }
+
+    @distributed_test(world_size=1)
+    def _helper():
+        model = SimpleModel(hidden_dim=10)
+        model, _, _, _ = deepspeed.initialize(args=None, model=model, config_params=config_dict)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=5,
+                                        hidden_dim=10,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+
+    _helper()
+
+
+def test_no_args(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "fp16": {
+            "enabled": True
+        }
+    }
+
+    @distributed_test(world_size=1)
+    def _helper():
+        model = SimpleModel(hidden_dim=10)
+        model, _, _, _ = deepspeed.initialize(model=model, config_params=config_dict)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=5,
+                                        hidden_dim=10,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+
+    _helper()
+
+
+def test_no_model(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "fp16": {
+            "enabled": True
+        }
+    }
+
+    @distributed_test(world_size=1)
+    def _helper():
+        model = SimpleModel(hidden_dim=10)
+        with pytest.raises(AssertionError):
+            model, _, _, _ = deepspeed.initialize(model=None, config_params=config_dict)
+
+        with pytest.raises(AssertionError):
+            model, _, _, _ = deepspeed.initialize(model, config_params=config_dict)
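Both pytest.raises blocks in test_no_model exercise the new `assert model is not None` guard: the first call passes model=None explicitly, while the second passes the model positionally, so it binds to the now-optional `args` parameter and `model` is left at its None default.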