Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 1-bit Adam support to DeepSpeed #380

Merged
merged 36 commits into from
Sep 9, 2020
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
fa66867
1-bit adam (#353)
awan-10 Sep 8, 2020
45619af
updated launcher to support pdsh and openmpi (#375)
jeffra Sep 8, 2020
6ae0437
revert timer and update args
jeffra Sep 8, 2020
036a96c
add mvapich to launcher
jeffra Sep 8, 2020
dce677e
update parse args
jeffra Sep 8, 2020
4fa6e05
convert ordered dict to list before indexing
jeffra Sep 8, 2020
2fdb79d
cleanup
jeffra Sep 8, 2020
1d6d5f5
update default launcher envs to include mpi vars
jeffra Sep 8, 2020
b3894e8
Merge branch 'master' into staging-1bit-adam-v2
jeffra Sep 8, 2020
65564e5
address comments on launcher
jeffra Sep 9, 2020
402dccd
Merge branch 'staging-1bit-adam-v2' of github.com:microsoft/DeepSpeed…
jeffra Sep 9, 2020
0bab578
Add a new test.
awan-10 Sep 9, 2020
f358e6a
Merge branch 'staging-1bit-adam-v2' of github.com:microsoft/DeepSpeed…
awan-10 Sep 9, 2020
cae1acc
Crucial fix in testing code.
awan-10 Sep 9, 2020
efe71b4
Fix formatting.
awan-10 Sep 9, 2020
1a124c4
Merge branch 'staging-1bit-adam-v2' of github.com:microsoft/DeepSpeed…
awan-10 Sep 9, 2020
5ecdcf0
Fix formatting.
awan-10 Sep 9, 2020
847447b
Separate host and cuda tests for ease of use.
awan-10 Sep 9, 2020
cdc1818
Merge branch 'master' into staging-1bit-adam-v2
jeffra Sep 9, 2020
03e99db
update copyright
jeffra Sep 9, 2020
769d9b5
Merge branch 'staging-1bit-adam-v2' of github.com:microsoft/DeepSpeed…
jeffra Sep 9, 2020
ba787c2
Significantly enhance the tutorial for 1-bit Adam.
awan-10 Sep 9, 2020
9cb5134
Merge branch 'staging-1bit-adam-v2' of github.com:microsoft/DeepSpeed…
awan-10 Sep 9, 2020
f623517
Fix typo.
awan-10 Sep 9, 2020
eb47cdf
install openmpi in conda
jeffra Sep 9, 2020
0ac97ea
Merge branch 'staging-1bit-adam-v2' of github.com:microsoft/DeepSpeed…
jeffra Sep 9, 2020
2b9d9c8
Update azure-pipelines.yml
jeffra Sep 9, 2020
5e33b08
update basic install test
jeffra Sep 9, 2020
8adb63d
small update to getting started docs
jeffra Sep 9, 2020
a76d2a2
Merge branch 'staging-1bit-adam-v2' of github.com:microsoft/DeepSpeed…
jeffra Sep 9, 2020
6693159
formatting
jeffra Sep 9, 2020
e944d22
output env
jeffra Sep 9, 2020
ded0475
Fix initializer.
awan-10 Sep 9, 2020
07fe582
Update setup.py
jeffra Sep 9, 2020
15d2abe
Update azure-pipelines.yml
jeffra Sep 9, 2020
6df3dbb
Merge branch 'master' into staging-1bit-adam-v2
jeffra Sep 9, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ jobs:
conda install -q --yes pip
conda install -q --yes gxx_linux-64
if [[ $(cuda.version) != "10.2" ]]; then conda install --yes -c conda-forge cudatoolkit-dev=$(cuda.version) ; fi
echo "PATH=$PATH, LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
displayName: 'Setup environment python=$(python.version) pytorch=$(pytorch.version) cuda=$(cuda.version)'

# Manually install torch/torchvision first to enforce versioning.
Expand Down
35 changes: 22 additions & 13 deletions basic_install_test.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,46 @@
import torch
import warnings
import importlib

try:
import deepspeed as ds
import deepspeed
print("deepspeed successfully imported")
except ImportError as err:
raise err

print(f"torch install path: {torch.__path__}")
print(f"torch version: {torch.__version__}")

print(f"deepspeed info: {ds.__version__}, {ds.__git_hash__}, {ds.__git_branch__}")
print(f"deepspeed install path: {deepspeed.__path__}")
print(
f"deepspeed info: {deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}"
)

try:
apex_C = importlib.import_module('apex_C')
print("apex successfully installed")
except Exception as err:
raise err

try:
fused_lamb = importlib.import_module('deepspeed.ops.lamb.fused_lamb_cuda')
print('deepspeed fused lamb kernels successfully installed')
except Exception as err:
raise err

try:
from apex.optimizers import FP16_Optimizer
print("using old-style apex")
except ImportError:
print("using new-style apex")

try:
ds_transformer = importlib.import_module(
'deepspeed.ops.transformer.transformer_cuda')
print('deepspeed transformer kernels successfully installed')
importlib.import_module('deepspeed.ops.lamb.fused_lamb_cuda')
print('deepspeed lamb successfully installed.')
except Exception as err:
raise err
warnings.warn("deepspeed lamb is NOT installed.")

try:
importlib.import_module('deepspeed.ops.transformer.transformer_cuda')
print('deepspeed transformer kernels successfully installed.')
except Exception as err:
warnings.warn('deepspeed transformer kernels are NOT installed.')

try:
importlib.import_module('deepspeed.ops.sparse_attention.cpp_utils')
print('deepspeed sparse attention successfully installed.')
except ImportError:
warnings.warn('deepspeed sparse attention is NOT installed.')
14 changes: 14 additions & 0 deletions deepspeed/launcher/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2020 The Microsoft DeepSpeed Team

#############################################
# Torch distributed constants
#############################################
# Default rendezvous port for torch.distributed (PyTorch's own default).
TORCH_DISTRIBUTED_DEFAULT_PORT = 29500

# pdsh (parallel distributed shell) launcher backend.
PDSH_LAUNCHER = 'pdsh'
# Maximum concurrent-node fanout passed to pdsh via its -f flag.
PDSH_MAX_FAN_OUT = 1024

# Open MPI (mpirun) launcher backend.
OPENMPI_LAUNCHER = 'openmpi'

# MVAPICH2 (mpirun) launcher backend.
MVAPICH_LAUNCHER = 'mvapich'
# Temporary hostfile written at launch time for the mvapich backend.
MVAPICH_TMP_HOSTFILE = '/tmp/deepspeed_mvapich_hostfile'
13 changes: 10 additions & 3 deletions deepspeed/launcher/launch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Copyright 2020 The Microsoft DeepSpeed Team
"""
Copyright 2020 The Microsoft DeepSpeed Team: [email protected]
DeepSpeed launcher, this is similar to torch.distributed.launch but supports
additional features such as arbitrary gpu exclusion.

deepspeed.launcher.launch is intended to be run on a single worker node and
will spawn several worker sub-processes depending on how many devices/ranks
are on the worker.
"""

import sys
Expand All @@ -10,7 +16,8 @@
from collections import defaultdict
from argparse import ArgumentParser, REMAINDER

from deepspeed.utils import logger
from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..utils import logger


def parse_args():
Expand All @@ -32,7 +39,7 @@ def parse_args():
" single node multi-proc training, the"
" --master_addr can simply be 127.0.0.1")
parser.add_argument("--master_port",
default=29500,
default=TORCH_DISTRIBUTED_DEFAULT_PORT,
type=int,
help="Master node (rank 0)'s free port that needs to "
"be used for communication during distributed "
Expand Down
189 changes: 189 additions & 0 deletions deepspeed/launcher/multinode_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import os
import sys
import shutil
import subprocess
import warnings
from abc import ABC, abstractmethod

from ..utils import logger
from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE


class MultiNodeRunner(ABC):
    """Common base for backends that launch DeepSpeed across multiple nodes.

    Subclasses implement ``backend_exists`` (is the backend installed?) and
    ``get_cmd`` (build the command line used to start the job).
    """
    def __init__(self, args, world_info_base64):
        # Keep the parsed CLI namespace first: parse_user_args reads it.
        self.args = args
        self.user_arguments = self.parse_user_args()
        self.user_script = args.user_script
        self.world_info_base64 = world_info_base64
        # Environment variables to forward to every launched process.
        self.exports = {}

    @abstractmethod
    def backend_exists(self):
        """Return a truthy value when this backend is usable on the host."""

    @abstractmethod
    def get_cmd(self, environment, active_resources):
        """Build the argv list that launches the job on the given resources."""

    def add_export(self, key, var):
        """Record an environment variable export (whitespace-trimmed)."""
        clean_key = key.strip()
        clean_var = var.strip()
        self.exports[clean_key] = clean_var

    def parse_user_args(self):
        # Default: pass the user's arguments through untouched.
        return self.args.user_args


class PDSHRunner(MultiNodeRunner):
    """Multi-node launcher that fans out over ssh via pdsh."""
    def __init__(self, args, world_info_base64):
        super().__init__(args, world_info_base64)

    def backend_exists(self):
        # Truthy (a path) when the pdsh binary is on PATH, else None.
        return shutil.which('pdsh')

    def parse_user_args(self):
        # Quote positional arguments so the remote shell does not re-split
        # them; flags (leading '-') are passed through as-is.
        quoted = []
        for arg in self.args.user_args:
            quoted.append(arg if arg.startswith("-") else "'{}'".format(arg))
        return quoted

    def get_cmd(self, environment, active_resources):
        environment['PDSH_RCMD_TYPE'] = 'ssh'

        active_workers = ",".join(active_resources.keys())
        logger.info("Running on the following workers: %s" % active_workers)

        # PDSH flags for max node fan out and specific hosts to launch on
        # See https://linux.die.net/man/1/pdsh for flag details
        pdsh_cmd_args = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers]

        # One "export K=V;" prefix per recorded environment variable.
        exports = "".join(
            "export {}={}; ".format(key, val) for key, val in self.exports.items())

        deepspeed_launch = [
            exports,
            "cd {};".format(os.path.abspath('.')),
            sys.executable,
            "-u",
            "-m",
            "deepspeed.launcher.launch",
            '--world_info={}'.format(self.world_info_base64),
            "--node_rank=%n",
            "--master_addr={}".format(self.args.master_addr),
            "--master_port={}".format(self.args.master_port)
        ]

        return (pdsh_cmd_args + deepspeed_launch + [self.user_script] +
                self.user_arguments)


class OpenMPIRunner(MultiNodeRunner):
    """Multi-node launcher that starts the job with Open MPI's mpirun."""
    def __init__(self, args, world_info_base64, resource_pool):
        super().__init__(args, world_info_base64)
        self.resource_pool = resource_pool
        # Restrict UCX to TCP transports.
        self.add_export('UCX_TLS', 'tcp')

    def backend_exists(self):
        # TODO: if IB is available we should suggest mvapich instead.
        # Truthy (a path) when ompi_info is on PATH, else None.
        return shutil.which('ompi_info')

    def get_cmd(self, environment, active_resources):
        # TODO: Allow for include/exclude at node-level but not gpu-level
        assert self.args.include == "" and self.args.exclude == "", 'openmpi backend does not support worker include/exclusion'
        assert self.args.num_nodes == -1 and self.args.num_gpus == -1, 'openmpi backend does not support limiting num nodes/gpus'

        # One process per device across the whole pool.
        total_process_count = sum(self.resource_pool.values())

        mpirun_cmd = [
            'mpirun',
            '-n',
            f'{total_process_count}',
            '-hostfile',
            f'{self.args.hostfile}',
            '--mca',
            'btl',
            '^openib',
            '--mca',
            'btl_tcp_if_include',
            'eth0',
        ]

        # Forward every recorded environment variable with -x.
        export_cmd = []
        for key, val in self.exports.items():
            export_cmd.extend(['-x', f'{key}={val}'])

        python_exec = [sys.executable, "-u"]

        return (mpirun_cmd + export_cmd + python_exec + [self.user_script] +
                self.user_arguments)


class MVAPICHRunner(MultiNodeRunner):
    """Multi-node launcher that starts the job with MVAPICH2-GDR's mpirun.

    Verifies the MVAPICH2-GDR build via ``mpiname`` and writes a temporary
    plain hostfile (one hostname per line) for mpirun.
    """
    def __init__(self, args, world_info_base64, resource_pool):
        super().__init__(args, world_info_base64)
        self.resource_pool = resource_pool

        # Disable the CMA kernel module, not available on Ubuntu systems
        self.add_export('MV2_SMP_USE_CMA', '0')

        # If we fail this will output more verbose logging
        self.add_export('MV2_DEBUG_SHOW_BACKTRACE', '1')

        # Enabled cuda-aware communication
        self.add_export('MV2_USE_CUDA', '1')

        # Support deep learning frameworks: http://hidl.cse.ohio-state.edu/userguide/horovod/
        self.add_export('MV2_SUPPORT_DL', '1')

        # Support MPI_THREAD_MULTIPLE
        self.add_export('MV2_ENABLE_AFFINITY', '0')

        # Performance tuning flags for allgather
        self.add_export('MV2_INTER_ALLGATHER_TUNING', '5')
        self.add_export('MV2_CUDA_USE_NAIVE', '0')

    def backend_exists(self):
        #TODO: if IB is available we should suggest mvapich
        mpiname_exists = shutil.which('mpiname')
        exists = False
        if not mpiname_exists:
            warnings.warn("mpiname does not exist, mvapich is not installed properly")
        else:
            # argv-list form instead of shell=True: no shell features are
            # needed for a bare command and it avoids shell parsing/injection.
            results = subprocess.check_output(['mpiname'])
            mpiname_results = results.decode('utf-8').strip()
            if "MVAPICH2-GDR" in mpiname_results:
                exists = True
            else:
                warnings.warn(
                    f"Expected MVAPICH2-GDR as return for mpiname but received {mpiname_results}"
                )
        return exists

    def get_cmd(self, environment, active_resources):
        #TODO: Allow for include/exclude at node-level but not gpu-level
        assert self.args.include == "" and self.args.exclude == "", 'mvapich backend does not support worker include/exclusion'
        assert self.args.num_nodes == -1 and self.args.num_gpus == -1, 'mvapich backend does not support limiting num nodes/gpus'
        devices_per_node = self.resource_pool.values()
        total_process_count = sum(devices_per_node)
        # -ppn requires a uniform device count; take the first node's count
        # without materializing the whole view into a list.
        process_per_node = next(iter(devices_per_node))
        assert all([n == process_per_node for n in devices_per_node]), "mvapich requires same number of devices per node"

        # mpirun expects a plain hostfile with one hostname per line.
        with open(MVAPICH_TMP_HOSTFILE, 'w') as fd:
            for host in self.resource_pool.keys():
                fd.write(f'{host}\n')

        mpirun_cmd = [
            'mpirun',
            '-np',
            f'{total_process_count}',
            '-ppn',
            f'{process_per_node}',
            '--hostfile',
            f'{MVAPICH_TMP_HOSTFILE}',
        ]

        # Forward every recorded environment variable with -env.
        export_cmd = []
        for k, v in self.exports.items():
            export_cmd += ['-env', f'{k}={v}']

        python_exec = [sys.executable, "-u"]

        return (mpirun_cmd + export_cmd + python_exec + [self.user_script] +
                self.user_arguments)
Loading