diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index c4c2acf0b0d7..3401f121bca0 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -49,8 +49,8 @@ def _parse_version(version_str):
 sys.modules['deepspeed.pt.loss_scaler'] = deepspeed.runtime.fp16.loss_scaler
 
 
-def initialize(args,
-               model,
+def initialize(args=None,
+               model=None,
                optimizer=None,
                model_parameters=None,
                training_data=None,
@@ -62,8 +62,7 @@ def initialize(args,
     """Initialize the DeepSpeed Engine.
 
     Arguments:
-        args: a dictionary containing local_rank and deepspeed_config
-            file location
+        args: an object containing local_rank and deepspeed_config fields. This is optional if `config_params` is passed.
 
         model: Required: nn.module class before apply any wrappers
 
@@ -88,6 +87,9 @@ def initialize(args,
            mini-batch of Tensor(s).  Used when using batched loading from a
            map-style dataset.
 
+        config_params: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed config
+            as a dictionary instead.
+
     Returns:
         A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler``
 
@@ -108,6 +110,8 @@ def initialize(args,
        __git_branch__),
             ranks=[0])
 
+    assert model is not None, "deepspeed.initialize requires a model"
+
    if not isinstance(model, PipelineModule):
        engine = DeepSpeedEngine(args=args,
                                 model=model,
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index f5737c07ea04..e11e2c1d7afc 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -495,9 +495,10 @@ def _configure_with_arguments(self, args, mpu):
         # After the distributed backend is initialized we are guaranteed the LOCAL_RANK
         # environment variable is set. We must align args.local_rank to this value for
         # backwards compatability with scripts relying on [args|self].local_rank containing
-        # the correct local rank info.
-        args.local_rank = int(os.environ['LOCAL_RANK'])
-        self.local_rank = args.local_rank
+        # the correct local rank info. _do_args_sanity_check will ensure this is the case.
+        self.local_rank = int(os.environ['LOCAL_RANK'])
+        if hasattr(args, 'local_rank'):
+            args.local_rank = self.local_rank
 
         config_file = args.deepspeed_config if hasattr(args,
                                                        'deepspeed_config') else None
@@ -513,15 +514,14 @@ def _do_args_sanity_check(self, args):
                 assert args.deepspeed_config is None, "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
             args.deepspeed_config = args.deepscale_config
 
-        local_rank_err = "DeepSpeed requires a command line parameter of --local_rank [int] and/or setting the LOCAL_RANK environment variable."
-        if hasattr(args, 'local_rank'):
-            assert type(args.local_rank) == int, local_rank_err
-            if "LOCAL_RANK" in os.environ and args.local_rank >= 0:
-                env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+        assert "LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment variable, it is set by the deepspeed launcher, " \
+            "deepspeed.init_distributed, or the torch.distributed launcher. If using a different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed."
+        if hasattr(args, 'local_rank') and args.local_rank != None:
+            assert isinstance(args.local_rank, int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}"
+            if args.local_rank >= 0:
+                env_local_rank = int(os.environ.get("LOCAL_RANK"))
                 assert env_local_rank == args.local_rank, \
                     f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}."
-        else:
-            assert "LOCAL_RANK" in os.environ, local_rank_err
 
         if self.config_params is None:
             assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index 4cabefe71a33..7de3a40fabeb 100755
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -226,3 +226,83 @@ def _helper():
         model.step()
 
     _helper()
+
+
+def test_none_args(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "fp16": {
+            "enabled": True
+        }
+    }
+
+    @distributed_test(world_size=1)
+    def _helper():
+        model = SimpleModel(hidden_dim=10)
+        model, _, _, _ = deepspeed.initialize(args=None, model=model, config_params=config_dict)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=5,
+                                        hidden_dim=10,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+
+    _helper()
+
+
+def test_no_args(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "fp16": {
+            "enabled": True
+        }
+    }
+
+    @distributed_test(world_size=1)
+    def _helper():
+        model = SimpleModel(hidden_dim=10)
+        model, _, _, _ = deepspeed.initialize(model=model, config_params=config_dict)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=5,
+                                        hidden_dim=10,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+
+    _helper()
+
+
+def test_no_model(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "fp16": {
+            "enabled": True
+        }
+    }
+
+    @distributed_test(world_size=1)
+    def _helper():
+        model = SimpleModel(hidden_dim=10)
+        with pytest.raises(AssertionError):
+            model, _, _, _ = deepspeed.initialize(model=None, config_params=config_dict)
+
+        with pytest.raises(AssertionError):
+            model, _, _, _ = deepspeed.initialize(model, config_params=config_dict)
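For reference, a minimal sketch of the calling convention this patch enables. The toy model, its name, and the config values below are made up for illustration; only deepspeed.initialize, model, model_parameters, and config_params come from the diff above, and the process is assumed to have LOCAL_RANK set (e.g. by running under the deepspeed launcher), as the new sanity check requires.

import torch
import deepspeed

class TinyNet(torch.nn.Module):
    # Hypothetical toy model used only to demonstrate the new API.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

ds_config = {
    "train_batch_size": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},
}

net = TinyNet()
# No argparse namespace and no on-disk JSON file: args is omitted entirely and
# the DeepSpeed config is passed in-memory via config_params.
engine, optimizer, _, _ = deepspeed.initialize(model=net,
                                               model_parameters=net.parameters(),
                                               config_params=ds_config)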