[SPARK-44264][ML][PYTHON] Write a Deepspeed Distributed Learning Class DeepspeedTorchDistributor

### What changes were proposed in this pull request?
Implements a distributed learning class, `DeepspeedTorchDistributor`, for running DeepSpeed workloads via the `torch.distributed.run` launcher, along with unit tests for several of its helper functions. Tests for the distributed workloads built by `_create_torchrun_command` still need to be added.

### Why are the changes needed?
DeepSpeed workloads require special launch commands. This class makes it easy to run DeepSpeed applications without ever touching the terminal: a user who needs the `torch.distributed.run` launcher can drive it directly from Python. The API and workflow closely mirror `TorchDistributor`: create an instance and invoke `distributor.run(...)`, as shown in the sketch below.
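
A minimal usage sketch based on the class added in this PR (the script path and config values are illustrative, not part of the change):

```python
from pyspark.ml.deepspeed.deepspeed_distributor import DeepspeedTorchDistributor

# A dict config is written to a temporary JSON file; a string is treated as a
# path to an existing DeepSpeed config file; None falls back to DeepSpeed defaults.
deepspeed_config = {"train_micro_batch_size_per_gpu": 8}

distributor = DeepspeedTorchDistributor(
    num_gpus=2,          # GPUs per node
    nnodes=2,            # number of nodes
    local_mode=False,    # run across the cluster rather than locally on the driver
    use_gpu=True,
    deepspeed_config=deepspeed_config,
)

# Only file-based training scripts are supported in this PR; passing a Python
# function raises RuntimeError for now. The script path here is hypothetical.
distributor.run("/path/to/train.py", "--num-epochs", "5")
```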

### Checklist

- [x] creates deepspeed command
- [x] can support running a python file in a distributed fashion
- [ ] supports distributed training with a function (to be implemented in future PR)
- [ ] cleans up any temporary files it made (future PR)
- [x] unit tests

Closes #41770 from mathewjacob1002/deepspeed.

Lead-authored-by: Mathew Jacob <[email protected]>
Co-authored-by: Mathew Jacob <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
2 people authored and HyukjinKwon committed Jul 11, 2023
1 parent 37aa62f commit 0d90f2a
Showing 4 changed files with 378 additions and 20 deletions.
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -624,6 +624,7 @@ def __hash__(self):
"pyspark.ml.torch.tests.test_distributor",
"pyspark.ml.torch.tests.test_log_communication",
"pyspark.ml.torch.tests.test_data_loader",
"pyspark.ml.deepspeed.tests.test_deepspeed_distributor",
"pyspark.ml.tests.connect.test_legacy_mode_summarizer",
"pyspark.ml.tests.connect.test_legacy_mode_evaluation",
"pyspark.ml.tests.connect.test_legacy_mode_feature",
157 changes: 157 additions & 0 deletions python/pyspark/ml/deepspeed/deepspeed_distributor.py
@@ -0,0 +1,157 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import sys
import tempfile
from typing import (
Union,
Callable,
List,
Dict,
Optional,
Any,
)

from pyspark.ml.torch.distributor import TorchDistributor


class DeepspeedTorchDistributor(TorchDistributor):

_DEEPSPEED_SSL_CONF = "deepspeed.spark.distributor.ignoreSsl"

def __init__(
self,
num_gpus: int = 1,
nnodes: int = 1,
local_mode: bool = True,
use_gpu: bool = True,
deepspeed_config: Optional[Union[str, Dict[str, Any]]] = None,
):
"""
This class is used to run deepspeed training workloads with spark clusters.
The user has the option to specify the number of gpus per node
and the number of nodes (the same as if running from terminal),
as well as specify a deepspeed configuration file.
Parameters
----------
num_gpus: int
            The number of GPUs to use per node (analogous to num_gpus in the deepspeed command).
nnodes: int
The number of nodes that should be used for the run.
local_mode: bool
            Whether to run the training locally on the driver (True) or distributed across the cluster (False).
use_gpu: bool
Boolean flag to determine whether to utilize gpus.
deepspeed_config: Union[Dict[str,Any], str] or None:
The configuration file to be used for launching the deepspeed application.
If it's a dictionary containing the parameters, then we will create the file.
If None, deepspeed will fall back to default parameters.
"""
num_processes = num_gpus * nnodes
self.deepspeed_config = deepspeed_config
super().__init__(
num_processes,
local_mode,
use_gpu,
_ssl_conf=DeepspeedTorchDistributor._DEEPSPEED_SSL_CONF,
)
self.cleanup_deepspeed_conf = False

@staticmethod
def _get_deepspeed_config_path(deepspeed_config: Union[str, Dict[str, Any]]) -> str:
if isinstance(deepspeed_config, dict):
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as file:
json.dump(deepspeed_config, file)
return file.name
deepspeed_config_path = deepspeed_config
        # An empty value means DeepSpeed will fall back to its default settings.
if deepspeed_config is None:
return ""
return deepspeed_config_path

@staticmethod
def _create_torchrun_command(
input_params: Dict[str, Any], train_path: str, *args: Any
) -> List[str]:
local_mode = input_params["local_mode"]
num_processes = input_params["num_processes"]
deepspeed_config = input_params["deepspeed_config"]
deepspeed_config_path = DeepspeedTorchDistributor._get_deepspeed_config_path(
deepspeed_config
)
torchrun_args, processes_per_node = TorchDistributor._get_torchrun_args(
local_mode, num_processes
)
args_string = list(map(str, args))
command_to_run = [
sys.executable,
"-m",
"torch.distributed.run",
*torchrun_args,
f"--nproc_per_node={processes_per_node}",
train_path,
*args_string,
"-deepspeed",
]

        # Omit the --deepspeed_config argument if no path or parameters were provided.
if deepspeed_config_path == "":
return command_to_run
return command_to_run + ["--deepspeed_config", deepspeed_config_path]

@staticmethod
def _run_training_on_pytorch_file(
input_params: Dict[str, Any], train_path: str, *args: Any, **kwargs: Any
) -> None:
if kwargs:
raise ValueError(
"DeepspeedTorchDistributor with pytorch file doesn't support keyword arguments"
)

log_streaming_client = input_params.get("log_streaming_client", None)
training_command = DeepspeedTorchDistributor._create_torchrun_command(
input_params, train_path, *args
)
DeepspeedTorchDistributor._execute_command(
training_command, log_streaming_client=log_streaming_client
)

def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Optional[Any]:
# If the "train_object" is a string, then we assume it's a filepath.
# Otherwise, we assume it's a function.
if isinstance(train_object, str):
if os.path.exists(train_object) is False:
raise FileNotFoundError(f"The path to training file {train_object} does not exist.")
framework_wrapper_fn = DeepspeedTorchDistributor._run_training_on_pytorch_file
else:
raise RuntimeError("Python training functions aren't supported as inputs at this time")

if self.local_mode:
return self._run_local_training(framework_wrapper_fn, train_object, *args, **kwargs)
return self._run_distributed_training(
framework_wrapper_fn,
train_object,
spark_dataframe=None,
*args,
**kwargs, # type:ignore[misc]
)
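
For reference, a minimal sketch (editor's illustration with hypothetical values) of the command that `_create_torchrun_command` assembles in local mode; a dict `deepspeed_config` is first serialized to a temporary JSON file by `_get_deepspeed_config_path`:

```python
import sys
from pyspark.ml.deepspeed.deepspeed_distributor import DeepspeedTorchDistributor

input_params = {
    "local_mode": True,   # local mode uses the --standalone --nnodes=1 torchrun args
    "num_processes": 4,   # becomes --nproc_per_node=4 in local mode
    "deepspeed_config": {"train_micro_batch_size_per_gpu": 8},  # dict -> temp .json file
}
cmd = DeepspeedTorchDistributor._create_torchrun_command(input_params, "train.py", "--epochs", "5")
# cmd == [sys.executable, "-m", "torch.distributed.run", "--standalone", "--nnodes=1",
#         "--nproc_per_node=4", "train.py", "--epochs", "5", "-deepspeed",
#         "--deepspeed_config", "<path to the generated temp .json>"]
```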
176 changes: 176 additions & 0 deletions python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
@@ -0,0 +1,176 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
from typing import Any, Tuple, Dict
import unittest

from pyspark.ml.deepspeed.deepspeed_distributor import DeepspeedTorchDistributor


class DeepspeedTorchDistributorUnitTests(unittest.TestCase):
def _get_env_var(self, var_name: str, default_value: Any) -> Any:
value = os.getenv(var_name)
if value:
return value
os.environ[var_name] = str(default_value)
return default_value

def _get_env_variables_distributed(self) -> Tuple[Any, Any, Any]:
MASTER_ADDR = self._get_env_var("MASTER_ADDR", "127.0.0.1")
MASTER_PORT = self._get_env_var("MASTER_PORT", 2000)
RANK = self._get_env_var("RANK", 0)
return MASTER_ADDR, MASTER_PORT, RANK

def test_get_torchrun_args_local(self) -> None:
number_of_processes = 5
EXPECTED_TORCHRUN_ARGS_LOCAL = ["--standalone", "--nnodes=1"]
EXPECTED_PROCESSES_PER_NODE_LOCAL = number_of_processes
(
get_local_mode_torchrun_args,
process_per_node,
) = DeepspeedTorchDistributor._get_torchrun_args(True, number_of_processes)
self.assertEqual(get_local_mode_torchrun_args, EXPECTED_TORCHRUN_ARGS_LOCAL)
self.assertEqual(EXPECTED_PROCESSES_PER_NODE_LOCAL, process_per_node)

def test_get_torchrun_args_distributed(self) -> None:
number_of_processes = 5
MASTER_ADDR, MASTER_PORT, RANK = self._get_env_variables_distributed()
EXPECTED_TORCHRUN_ARGS_DISTRIBUTED = [
f"--nnodes={number_of_processes}",
f"--node_rank={RANK}",
f"--rdzv_endpoint={MASTER_ADDR}:{MASTER_PORT}",
"--rdzv_id=0",
]
torchrun_args_distributed, process_per_node = DeepspeedTorchDistributor._get_torchrun_args(
False, number_of_processes
)
self.assertEqual(torchrun_args_distributed, EXPECTED_TORCHRUN_ARGS_DISTRIBUTED)
self.assertEqual(process_per_node, 1)

def test_create_torchrun_command_local(self) -> None:
DEEPSPEED_CONF = "path/to/deepspeed"
TRAIN_FILE_PATH = "path/to/exec"
NUM_PROCS = 10
input_params: Dict[str, Any] = {}
input_params["local_mode"] = True
input_params["num_processes"] = NUM_PROCS
input_params["deepspeed_config"] = DEEPSPEED_CONF

torchrun_local_args_expected = ["--standalone", "--nnodes=1"]
with self.subTest(msg="Testing local training with no extra args"):
LOCAL_CMD_NO_ARGS_EXPECTED = [
sys.executable,
"-m",
"torch.distributed.run",
*torchrun_local_args_expected,
f"--nproc_per_node={NUM_PROCS}",
TRAIN_FILE_PATH,
"-deepspeed",
"--deepspeed_config",
DEEPSPEED_CONF,
]
local_cmd = DeepspeedTorchDistributor._create_torchrun_command(
input_params, TRAIN_FILE_PATH
)
self.assertEqual(local_cmd, LOCAL_CMD_NO_ARGS_EXPECTED)
with self.subTest(msg="Testing local training with extra args for the training script"):
local_mode_version_args = ["--arg1", "--arg2"]
LOCAL_CMD_ARGS_EXPECTED = [
sys.executable,
"-m",
"torch.distributed.run",
*torchrun_local_args_expected,
f"--nproc_per_node={NUM_PROCS}",
TRAIN_FILE_PATH,
*local_mode_version_args,
"-deepspeed",
"--deepspeed_config",
DEEPSPEED_CONF,
]

local_cmd_with_args = DeepspeedTorchDistributor._create_torchrun_command(
input_params, TRAIN_FILE_PATH, *local_mode_version_args
)
self.assertEqual(local_cmd_with_args, LOCAL_CMD_ARGS_EXPECTED)

def test_create_torchrun_command_distributed(self) -> None:
DEEPSPEED_CONF = "path/to/deepspeed"
TRAIN_FILE_PATH = "path/to/exec"
NUM_PROCS = 10
input_params: Dict[str, Any] = {}
input_params["local_mode"] = True
input_params["num_processes"] = NUM_PROCS
input_params["deepspeed_config"] = DEEPSPEED_CONF
(
distributed_master_address,
distributed_master_port,
distributed_rank,
) = self._get_env_variables_distributed()
distributed_torchrun_args = [
f"--nnodes={NUM_PROCS}",
f"--node_rank={distributed_rank}",
f"--rdzv_endpoint={distributed_master_address}:{distributed_master_port}",
"--rdzv_id=0",
]
with self.subTest(msg="Distributed training command verification with no extra args"):
DISTRIBUTED_CMD_NO_ARGS_EXPECTED = [
sys.executable,
"-m",
"torch.distributed.run",
*distributed_torchrun_args,
"--nproc_per_node=1",
TRAIN_FILE_PATH,
"-deepspeed",
"--deepspeed_config",
DEEPSPEED_CONF,
]
input_params["local_mode"] = False
distributed_command = DeepspeedTorchDistributor._create_torchrun_command(
input_params, TRAIN_FILE_PATH
)
self.assertEqual(DISTRIBUTED_CMD_NO_ARGS_EXPECTED, distributed_command)
with self.subTest(msg="Distributed training command verification with extra arguments"):
distributed_extra_args = ["-args1", "--args2"]
DISTRIBUTED_CMD_ARGS_EXPECTED = [
sys.executable,
"-m",
"torch.distributed.run",
*distributed_torchrun_args,
"--nproc_per_node=1",
TRAIN_FILE_PATH,
*distributed_extra_args,
"-deepspeed",
"--deepspeed_config",
DEEPSPEED_CONF,
]
distributed_command_with_args = DeepspeedTorchDistributor._create_torchrun_command(
input_params, TRAIN_FILE_PATH, *distributed_extra_args
)
self.assertEqual(DISTRIBUTED_CMD_ARGS_EXPECTED, distributed_command_with_args)


if __name__ == "__main__":
from pyspark.ml.deepspeed.tests.test_deepspeed_distributor import * # noqa: F401,F403

try:
import xmlrunner # type:ignore

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)