diff --git a/CHANGELOG.md b/CHANGELOG.md
index cece9c30ec960..e59d48fbc8665 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -141,6 +141,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
 
 
+- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))
+
+
 ## [1.2.4] - 2021-03-16
 
 ### Changed
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 59cdb12f7cca9..ceb9d98505acc 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -65,17 +65,28 @@ def __init__(
         self.lr_schedulers: Sequence = []
         self.optimizer_frequencies: Sequence = []
 
-    def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
+    def connect(self, model: LightningModule) -> None:
+        """Transfers ownership of the model to this plugin"""
+        self.training_type_plugin.connect(model)
+
+    def setup_environment(self) -> None:
+        """
-        Connects the plugins to the training process, creates optimizers
+        Setup any processes or distributed connections.
+        This is called before the LightningModule/DataModule setup hook
+        which allows the user to access the accelerator environment before setup is complete.
+        """
+        self.training_type_plugin.setup_environment()
 
+    def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
+        """
+        Setup plugins for the trainer fit and creates optimizers.
 
         Args:
-            trainer: the trainer instance to connect to
-            model: the model to train
+            trainer: the trainer instance
+            model: the LightningModule
         """
-        self.connect_training_type_plugin(self.training_type_plugin, model)
+        self.setup_training_type_plugin(self.training_type_plugin, model)
         self.setup_optimizers(trainer)
-        self.connect_precision_plugin(self.precision_plugin)
+        self.setup_precision_plugin(self.precision_plugin)
 
     def start_training(self, trainer: 'Trainer') -> None:
         self.training_type_plugin.start_training(trainer)
@@ -332,14 +343,11 @@ def setup_optimizers(self, trainer: 'Trainer') -> None:
         self.lr_schedulers = lr_schedulers
         self.optimizer_frequencies = optimizer_frequencies
 
-    def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
-        """Attaches the training type plugin to the accelerator.
-        Also transfers ownership of the model to this plugin
-
-        """
-        plugin.connect(model)
+    def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
+        """Attaches the training type plugin to the accelerator."""
+        plugin.setup(model)
 
-    def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
+    def setup_precision_plugin(self, plugin: PrecisionPlugin) -> None:
         """Attaches the precision plugin to the accelerator"""
         model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
         self.model = model
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 901e9c85b162b..bcadf16607b4f 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -80,9 +80,7 @@ def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
         return distributed_sampler_kwargs
 
-    def setup(self, model):
-        self._model = model
-
+    def setup_environment(self):
         # start the other scripts
         if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
             self._call_children_scripts()
@@ -90,6 +88,8 @@ def setup(self, model):
         # set the task idx
         self.task_idx = self.cluster_environment.local_rank()
 
+        self.setup_distributed()
+
     def _call_children_scripts(self):
         # bookkeeping of spawned processes
 
@@ -161,6 +161,34 @@ def _call_children_scripts(self):
             delay = np.random.uniform(1, 5, 1)[0]
             sleep(delay)
 
+    def setup_distributed(self):
+        # TODO: check if needed
+        seed = os.environ.get("PL_GLOBAL_SEED")
+        if seed is not None:
+            seed_everything(int(seed))
+
+        # determine which process we are and world size
+        self.set_world_ranks()
+
+        # set warning rank
+        rank_zero_only.rank = self.global_rank
+
+        # set up server using proc 0's ip address
+        # try to init for 20 times at max in case ports are taken
+        # where to store ip_table
+        self.init_ddp_connection(self.global_rank, self.world_size)
+
+        # on world_size=0 let everyone know training is starting
+        if self.is_global_zero and not torch.distributed.is_initialized():
+            log.info("-" * 100)
+            log.info(f"distributed_backend={self.distributed_backend}")
+            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
+            log.info("-" * 100)
+
+        # set the ranks and devices
+        self.dist.rank = self.global_rank
+        self.dist.device = self.root_device
+
     def _check_can_spawn_children(self):
         if self._has_spawned_children:
             raise RuntimeError(
@@ -213,37 +241,6 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
             torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
 
     def pre_dispatch(self):
-        # TODO: check if needed
-        seed = os.environ.get("PL_GLOBAL_SEED")
-        if seed is not None:
-            seed_everything(int(seed))
-
-        # determine which process we are and world size
-        self.set_world_ranks()
-
-        # set warning rank
-        rank_zero_only.rank = self.global_rank
-
-        # set up server using proc 0's ip address
-        # try to init for 20 times at max in case ports are taken
-        # where to store ip_table
-        self.init_ddp_connection(self.global_rank, self.world_size)
-
-        # TODO: we moved it to the trainer.fit after calling pre_dispatch
-        # ... need to double check that it is the correct place
-        # self.trainer.call_setup_hook(self.model)
-
-        # on world_size=0 let everyone know training is starting
-        if self.is_global_zero and not torch.distributed.is_initialized():
-            log.info("-" * 100)
-            log.info(f"distributed_backend={self.distributed_backend}")
-            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
-            log.info("-" * 100)
-
-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
-
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index d362912fa0185..ea1efd6e15873 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -77,8 +77,6 @@ def distributed_sampler_kwargs(self):
         return distributed_sampler_kwargs
 
     def setup(self, model):
-        self._model = model
-
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
 
         # pass in a state q
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index b54155d60eae5..b196044937414 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -192,17 +192,7 @@ def _load_config(self, config):
         return config
 
     def pre_dispatch(self):
-        self.set_world_ranks()
-        self.init_ddp_connection(self.global_rank, self.world_size)
-
         self.init_deepspeed()
-
-        # set warning rank
-        rank_zero_only.rank = self.global_rank
-
-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
         self.barrier()
 
     def init_deepspeed(self):
diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py
index 283d7113795ec..d9a8e70588c43 100644
--- a/pytorch_lightning/plugins/training_type/parallel.py
+++ b/pytorch_lightning/plugins/training_type/parallel.py
@@ -53,14 +53,6 @@ def on_gpu(self):
     def lightning_module(self):
         return unwrap_lightning_module(self._model)
 
-    @abstractmethod
-    def setup(self, model):
-        raise NotImplementedError
-
-    def connect(self, model, *args, **kwargs):
-        self.setup(model)
-        return self.model
-
     @property
     def is_global_zero(self) -> bool:
         return self.global_rank == 0
diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py
index 375a84da20d4b..d70779adf3ba1 100644
--- a/pytorch_lightning/plugins/training_type/single_device.py
+++ b/pytorch_lightning/plugins/training_type/single_device.py
@@ -64,8 +64,7 @@ def model_to_device(self) -> None:
         self._model.to(self.root_device)
 
-    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
-        self._model = model
+    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
         self.model_to_device()
         return self.model
 
diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index d3cbd0d6b5d79..b8d670ff16881 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -39,13 +39,8 @@ def __init__(self, device: Union[torch.device, int]):
     def on_tpu(self) -> bool:
         return True
 
-    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
-        self._model = model
-        self.model_to_device()
-        return self._model
-
     def model_to_device(self) -> None:
-        self._model.to(self.root_device)
+        self.model.to(self.root_device)
 
     def pre_dispatch(self) -> None:
         if isinstance(self.device, int):
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 22a18bf338f2a..c883ff504f24d 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -53,10 +53,9 @@ def __init__(
         self.tpu_local_core_rank = 0
         self.start_method = None
 
-    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
+    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
         self.create_mp_queue()
-        self._model = model
-        return self._model
+        return self.model
 
     def create_mp_queue(self):
         self.start_method = 'fork'
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 534e189dcff02..6a87792c7bd03 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -34,9 +34,19 @@ def __init__(self) -> None:
         self._model = None
         self._results = None
 
-    @abstractmethod
     def connect(self, model: 'Module') -> None:
-        """Called by the accelerator to connect it with this plugin"""
+        """Called by the accelerator to connect the accelerator and the model with this plugin"""
+        self.model = model
+
+    def setup_environment(self) -> None:
+        """
+        Setup any processes or distributed connections.
+        This is called before the LightningModule/DataModule setup hook
+        which allows the user to access the accelerator environment before setup is complete.
+        """
+
+    def setup(self, model: 'Module') -> None:
+        """Called by the accelerator to finish setup."""
 
     @property
     @abstractmethod
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 44b0e716a90c0..53b4920bd85ef 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -428,8 +428,10 @@ def fit(
         # ----------------------------
         # SET UP TRAINING
         # ----------------------------
-        self.call_setup_hook(model)
         self.call_hook("on_before_accelerator_backend_setup", model)
+        self.accelerator.connect(model)
+        self.accelerator.setup_environment()
+        self.call_setup_hook(model)  # allow user to setup lightning_module in accelerator environment
         self.accelerator.setup(self, model)  # note: this sets up self.lightning_module
 
         # ----------------------------
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index e6139de5d3028..79a17df074e35 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -98,7 +98,8 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
         "SLURM_LOCALID": "10"
     }
 )
-def test_accelerator_choice_ddp_slurm():
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_slurm(setup_distributed_mock):
 
     class CB(Callback):
 
@@ -136,7 +137,8 @@ def on_fit_start(self, trainer, pl_module):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=2)
-def test_accelerator_choice_ddp2_slurm(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp2_slurm(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -165,7 +167,8 @@ def on_fit_start(self, trainer, pl_module):
 @RunIf(min_gpus=1)
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
 @mock.patch('torch.cuda.device_count', return_value=2)
-def test_accelerator_choice_ddp_te(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_te(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -193,7 +196,8 @@ def on_fit_start(self, trainer, pl_module):
 @RunIf(min_gpus=1)
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
 @mock.patch('torch.cuda.device_count', return_value=2)
-def test_accelerator_choice_ddp2_te(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp2_te(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -224,7 +228,8 @@ def on_fit_start(self, trainer, pl_module):
     "NODE_RANK": "0",
 })
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_accelerator_choice_ddp_cpu_te(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_cpu_te(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -259,7 +264,8 @@ def on_fit_start(self, trainer, pl_module):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_accelerator_choice_ddp_cpu_slurm(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
@@ -294,7 +300,8 @@ def on_fit_start(self, trainer, pl_module):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock, setup_distributed_mock):
     """
     Test that we choose the custom cluster even when SLURM or TE flags are around
     """
@@ -304,6 +311,9 @@ class CustomCluster(LightningEnvironment):
         def master_address(self):
             return 'asdf'
 
+        def creates_children(self) -> bool:
+            return True
+
     class CB(Callback):
 
         def on_fit_start(self, trainer, pl_module):
@@ -336,7 +346,8 @@ def on_fit_start(self, trainer, pl_module):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_custom_accelerator(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_custom_accelerator(device_count_mock, setup_distributed_mock):
 
     class Accel(Accelerator):
         pass
@@ -371,7 +382,8 @@ class TrainTypePlugin(SingleDevicePlugin):
     }
 )
 @mock.patch('torch.cuda.device_count', return_value=0)
-def test_dist_backend_accelerator_mapping(device_count_mock):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_dist_backend_accelerator_mapping(device_count_mock, setup_distributed_mock):
 
     class CB(Callback):
 
diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py
index 14e73d920af4b..541110ac8846b 100644
--- a/tests/accelerators/test_ddp.py
+++ b/tests/accelerators/test_ddp.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from typing import Optional
+from unittest import mock
 from unittest.mock import patch
 
 import pytest
@@ -91,7 +93,6 @@ def test_torch_distributed_backend_env_variables(tmpdir):
     _environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"}
     with patch.dict(os.environ, _environ), \
          patch('torch.cuda.device_count', return_value=2):
-
         with pytest.raises(ValueError, match="Invalid backend: 'undefined'"):
             model = BoringModel()
             trainer = Trainer(
@@ -102,3 +103,30 @@ def test_torch_distributed_backend_env_variables(tmpdir):
                 logger=False,
             )
             trainer.fit(model)
+
+
+@RunIf(skip_windows=True)
+@mock.patch('torch.cuda.device_count', return_value=1)
+@mock.patch('torch.cuda.is_available', return_value=True)
+@mock.patch('torch.cuda.set_device')
+@mock.patch.dict(os.environ, {'PL_TORCH_DISTRIBUTED_BACKEND': 'gloo'}, clear=True)
+def test_ddp_torch_dist_is_available_in_setup(mock_set_device, mock_is_available, mock_device_count, tmpdir):
+    """
+    Test to ensure torch distributed is available within the setup hook using ddp
+    """
+
+    class TestModel(BoringModel):
+
+        def setup(self, stage: Optional[str] = None) -> None:
+            assert torch.distributed.is_initialized()
+            raise SystemExit()
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        accelerator="ddp",
+        gpus=1,
+    )
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 608f7bf1051f6..fdefc6ae9ef1c 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -46,8 +46,8 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
     assert callback_mock.method_calls == [
         call.on_init_start(trainer),
         call.on_init_end(trainer),
-        call.setup(trainer, model, 'fit'),
         call.on_before_accelerator_backend_setup(trainer, model),
+        call.setup(trainer, model, 'fit'),
         call.on_fit_start(trainer, model),
         call.on_pretrain_routine_start(trainer, model),
         call.on_pretrain_routine_end(trainer, model),
@@ -115,8 +115,8 @@ def test_trainer_callback_hook_system_test(tmpdir):
     assert callback_mock.method_calls == [
         call.on_init_start(trainer),
         call.on_init_end(trainer),
-        call.setup(trainer, model, 'test'),
         call.on_before_accelerator_backend_setup(trainer, model),
+        call.setup(trainer, model, 'test'),
         call.on_test_start(trainer, model),
         call.on_test_epoch_start(trainer, model),
         call.on_test_batch_start(trainer, model, ANY, 0, 0),
@@ -148,8 +148,8 @@ def test_trainer_callback_hook_system_validate(tmpdir):
     assert callback_mock.method_calls == [
         call.on_init_start(trainer),
         call.on_init_end(trainer),
-        call.setup(trainer, model, 'validate'),
         call.on_before_accelerator_backend_setup(trainer, model),
+        call.setup(trainer, model, 'validate'),
         call.on_validation_start(trainer, model),
         call.on_validation_epoch_start(trainer, model),
         call.on_validation_batch_start(trainer, model, ANY, 0, 0),
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index cf5c23a824732..e6b15069f256a 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -180,7 +180,7 @@ def test_deepspeed_defaults(tmpdir):
     assert isinstance(plugin.config["zero_optimization"], dict)
 
 
-@RunIf(deepspeed=True)
+@RunIf(min_gpus=1, deepspeed=True, special=True)
 def test_invalid_deepspeed_defaults_no_precision(tmpdir):
     """Test to ensure that using defaults, if precision is not set to 16, we throw an exception."""
     model = BoringModel()
@@ -195,7 +195,7 @@ def test_invalid_deepspeed_defaults_no_precision(tmpdir):
         trainer.fit(model)
 
 
-@RunIf(min_gpus=1, deepspeed=True)
+@RunIf(min_gpus=1, deepspeed=True, special=True)
 def test_warn_deepspeed_override_backward(tmpdir):
     """Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning."""
 
@@ -216,7 +216,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
         trainer.fit(model)
 
 
-@RunIf(min_gpus=1, deepspeed=True)
+@RunIf(min_gpus=1, deepspeed=True, special=True)
 def test_deepspeed_run_configure_optimizers(tmpdir):
     """Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation),
     whilst using configure_optimizers for optimizers and schedulers."""
@@ -246,7 +246,7 @@ def on_train_start(self) -> None:
     _assert_save_model_is_equal(model, tmpdir, trainer)
 
 
-@RunIf(min_gpus=1, deepspeed=True)
+@RunIf(min_gpus=1, deepspeed=True, special=True)
 def test_deepspeed_config(tmpdir, deepspeed_zero_config):
     """
     Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers
@@ -280,7 +280,7 @@ def on_train_start(self) -> None:
     _assert_save_model_is_equal(model, tmpdir, trainer)
 
 
-@RunIf(min_gpus=1, deepspeed=True)
+@RunIf(min_gpus=1, deepspeed=True, special=True)
 def test_deepspeed_custom_precision_params(tmpdir):
     """Ensure if we modify the FP16 parameters via the DeepSpeedPlugin, the deepspeed config contains these changes."""
 
@@ -301,7 +301,7 @@ def on_train_start(self) -> None:
         trainer.fit(model)
 
 
-@RunIf(min_gpus=1, deepspeed=True)
+@RunIf(min_gpus=1, deepspeed=True, special=True)
 def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_config):
     """Ensure if we use a config and turn off cpu_offload, that this is set to False within the config."""
 
diff --git a/tests/special_tests.sh b/tests/special_tests.sh
index 43658721e9226..dd67af470c4ec 100644
--- a/tests/special_tests.sh
+++ b/tests/special_tests.sh
@@ -17,6 +17,12 @@ export PL_RUNNING_SPECIAL_TESTS=1
 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no"
 python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp
 python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp
+python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_invalid_deepspeed_defaults_no_precision
+python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_warn_deepspeed_override_backward
+python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_run_configure_optimizers
+python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_config
+python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_custom_precision_params
+python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_assert_config_zero_offload_disabled
 python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_multigpu
 python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
 python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual
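With this change, `Trainer.fit` connects the accelerator and calls `setup_environment()` before the LightningModule/DataModule `setup` hook runs, so under DDP the process group is already initialized inside `setup`. Below is a minimal sketch of user code relying on that behaviour, mirroring the test added in `tests/accelerators/test_ddp.py`; the `SetupAwareModel` name and the `tests.helpers.boring_model` import path are illustrative assumptions, not part of the patch.

```python
from typing import Optional

import torch
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # assumed location of the BoringModel test helper


class SetupAwareModel(BoringModel):  # hypothetical example model, not part of the patch

    def setup(self, stage: Optional[str] = None) -> None:
        # With the new call order (connect -> setup_environment -> call_setup_hook -> accelerator.setup),
        # the DDP process group already exists when this hook runs.
        assert torch.distributed.is_available() and torch.distributed.is_initialized()
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        print(f"setup hook running on rank {rank} of {world_size}")


if __name__ == "__main__":
    # Requires a CUDA device and a DDP-capable environment, like the new test above.
    trainer = Trainer(fast_dev_run=True, accelerator="ddp", gpus=1)
    trainer.fit(SetupAwareModel())
```

Splitting `setup_environment()` out of `setup()` is what makes this possible: the distributed connection is established first, while optimizers and the precision plugin are still configured afterwards by `Accelerator.setup`.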