[WIP] Fix Trainer.test in ddp before running Trainer.fit #2790

Closed · wants to merge 198 commits
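
For context, the scenario this PR targets is calling Trainer.test before Trainer.fit with the ddp backend. A minimal sketch of that call order is below; the BoringModel, the 2-GPU setup and the 0.9-era distributed_backend argument are illustrative assumptions, not code taken from this PR.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class BoringModel(pl.LightningModule):
    """Tiny module used only to illustrate the test-then-fit call order."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': torch.nn.functional.cross_entropy(self(x), y)}

    def test_step(self, batch, batch_idx):
        x, y = batch
        return {'test_loss': torch.nn.functional.cross_entropy(self(x), y)}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(data, batch_size=8)

    def test_dataloader(self):
        return self.train_dataloader()


if __name__ == '__main__':
    model = BoringModel()
    trainer = pl.Trainer(gpus=2, distributed_backend='ddp', max_epochs=1)
    # The sequence under test: evaluate first, then train. The DDP process
    # group (and MASTER_PORT) set up by .test() must be reusable by .fit(),
    # which is what the changes below work toward.
    trainer.test(model)
    trainer.fit(model)
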
Commits (198):
ca98878
do not force
Aug 1, 2020
b0dbc28
debug
Aug 3, 2020
c9f91e0
debug
Aug 3, 2020
e64a56f
debug
Aug 3, 2020
602809f
debug
Aug 3, 2020
fb2e0c8
debug
Aug 3, 2020
f238dc3
debug
Aug 3, 2020
81c5255
debug
Aug 3, 2020
d48e147
debug
Aug 3, 2020
885c1d7
debug
Aug 3, 2020
c1da18b
debug
Aug 3, 2020
98cb6bb
debug
Aug 3, 2020
87e4a78
debug
Aug 3, 2020
8279449
debug
Aug 3, 2020
3fbdf76
debug
Aug 3, 2020
34ac16b
debug
Aug 3, 2020
24fb056
debug
Aug 3, 2020
8d9c49e
Merge branch 'master' into bugfix/test-before-fit
awaelchli Aug 3, 2020
9fce421
merge
awaelchli Aug 3, 2020
7b42a0f
debug
Aug 3, 2020
223136b
debug
Aug 3, 2020
463dfbb
debug
Aug 3, 2020
8395e14
debug
Aug 3, 2020
3453ee2
debug
Aug 4, 2020
69804b7
debug
Aug 4, 2020
60241b5
debug
Aug 4, 2020
700d881
debug
Aug 4, 2020
9928148
debug
Aug 4, 2020
f3c4404
debug
Aug 4, 2020
d95cc46
debug
Aug 4, 2020
50ab31e
debug
Aug 4, 2020
43f2d65
debug
Aug 4, 2020
752dbf1
debug
Aug 4, 2020
fc15ea7
debug
Aug 4, 2020
61e90f5
debug
Aug 4, 2020
bf30a98
debug
Aug 4, 2020
703c1c9
debug
Aug 4, 2020
414e6cc
debug
Aug 4, 2020
3a75faf
debug
Aug 4, 2020
6d8cd81
debug
Aug 4, 2020
85f8929
debug
Aug 4, 2020
f3bb93d
debug
Aug 4, 2020
53a7338
debug
Aug 4, 2020
7a35761
debug
Aug 4, 2020
79358fc
debug
Aug 4, 2020
f106dfb
debug
Aug 4, 2020
3d01604
debug
Aug 4, 2020
f97a8ed
debug
Aug 4, 2020
b426258
debug
Aug 4, 2020
cf09642
debug
Aug 4, 2020
138c906
debug
Aug 4, 2020
b3665d7
debug
Aug 4, 2020
47c4800
debug
Aug 4, 2020
6a9750f
debug
Aug 4, 2020
4e39510
ddptest
Aug 6, 2020
7d82e6b
ddptest
Aug 6, 2020
6c4e4c9
ddptest
Aug 6, 2020
111633d
ddptest
Aug 6, 2020
87ee614
ddptest
Aug 6, 2020
1a26952
ddptest
Aug 6, 2020
6354b21
ddptest
Aug 6, 2020
e7b6ea4
ddptest
Aug 6, 2020
ab94100
ddptest
Aug 6, 2020
18e47c7
ddptest
Aug 6, 2020
d396f7f
ddptest
Aug 6, 2020
5024dcc
ddptest
Aug 6, 2020
26d49c8
ddptest
Aug 6, 2020
bd8f762
ddptest
Aug 6, 2020
f3fe1bc
ddptest
Aug 6, 2020
e4d1823
ddptest
Aug 6, 2020
924b26a
ddptest
Aug 6, 2020
38b89d8
ddptest
Aug 6, 2020
6bd3cec
ddptest
Aug 6, 2020
4431213
ddptest
Aug 6, 2020
28ab5cd
ddptest
Aug 6, 2020
dc16a1f
add ddp script variations
Aug 6, 2020
9031558
add ddp test
Aug 6, 2020
b5bc4d6
rename
Aug 7, 2020
13fc64a
shell
Aug 7, 2020
3163db8
test
Aug 7, 2020
bd189a9
test
Aug 7, 2020
ce4274f
try call
Aug 7, 2020
886ce19
try without subprocess
Aug 7, 2020
884e759
test
Aug 7, 2020
65c1cff
display the error
Aug 7, 2020
d6c57eb
list all variations
awaelchli Aug 8, 2020
3be75ba
try string
awaelchli Aug 9, 2020
25a2748
try copy env
Aug 9, 2020
0911f31
debug
Aug 9, 2020
e700f81
pythonpath
Aug 9, 2020
83bd213
path
Aug 9, 2020
1cecde9
update test
Aug 9, 2020
1316c55
change
Aug 9, 2020
30ad2e7
Merge branch 'ddp_testing' into bugfix/test-before-fit
Aug 9, 2020
61a80ec
remove old file
Aug 9, 2020
462776b
debug
Aug 9, 2020
764c06a
try new
Aug 9, 2020
69fe561
port
Aug 9, 2020
844f106
debug
Aug 9, 2020
a44b9e3
debug
Aug 9, 2020
e712eb9
debug
Aug 9, 2020
5c21884
debug
Aug 10, 2020
2fe51fa
debug
Aug 10, 2020
59c0173
debug
Aug 10, 2020
f3d0190
debug
Aug 10, 2020
5c06679
debug
Aug 10, 2020
5ba3962
debug
Aug 10, 2020
e74cb9c
debug
Aug 10, 2020
a7c732d
debug
Aug 10, 2020
fa5d177
debug
Aug 10, 2020
01a8f11
debug
Aug 10, 2020
3ac5609
debug
Aug 10, 2020
0531f11
debug
Aug 10, 2020
2431333
debug
Aug 10, 2020
7b40fc0
debug
Aug 10, 2020
a293da0
debug
Aug 10, 2020
ee393bd
debug
Aug 10, 2020
a4c546a
debug
Aug 10, 2020
ba517bd
debug
Aug 10, 2020
308ed14
debug
Aug 10, 2020
9f34b2c
debug
Aug 10, 2020
49ed09d
debug
Aug 10, 2020
1874b8a
debug
Aug 10, 2020
b22bd74
debug
Aug 10, 2020
46915c6
cleanup
Aug 10, 2020
c3f9c86
cleanup
Aug 10, 2020
454d4cf
cleanup
Aug 10, 2020
27a815f
move class
Aug 10, 2020
748a963
cleanup
Aug 10, 2020
ce2f31e
cleanup
Aug 10, 2020
76fe75b
cleanup
Aug 10, 2020
6c45ebc
cleanup
Aug 10, 2020
0530234
cleanup
Aug 10, 2020
fe59656
cleanup
Aug 10, 2020
cbab095
cleanup
Aug 10, 2020
9c3dde5
cleanup
Aug 10, 2020
02a5070
cleanup
Aug 10, 2020
0c5592c
cleanup
Aug 10, 2020
f1c5edc
cleanup
Aug 10, 2020
59c95ac
cleanup
Aug 10, 2020
75d4085
Merge branch 'master' into bugfix/test-before-fit
Aug 10, 2020
ed4058f
merge
Aug 10, 2020
1aa0591
cleanup
Aug 10, 2020
c81138f
cleanup
Aug 10, 2020
528381b
try atexit handler
Aug 10, 2020
a0dca5b
cleanup
Aug 10, 2020
7a16c32
cleanup
Aug 10, 2020
473f004
add note about teardown
Aug 10, 2020
c7365fd
cleanup
Aug 10, 2020
d432f56
cleanup
Aug 10, 2020
dbac944
cleanup
Aug 10, 2020
ce1de36
cleanup
Aug 10, 2020
f6dfab9
cleanup
Aug 10, 2020
c527ab5
cleanup
Aug 10, 2020
48263a8
repair
Aug 10, 2020
f393d46
repair
Aug 10, 2020
d8b7d66
repair
Aug 10, 2020
569fe0e
repair
Aug 10, 2020
3d66bac
repair
Aug 11, 2020
ce59c5f
repair
Aug 11, 2020
e53dbe0
repair
Aug 11, 2020
cab2245
repair
Aug 11, 2020
f7fb55d
repair
Aug 11, 2020
d9bd460
repair
Aug 11, 2020
d6fd24c
repair
Aug 11, 2020
4bf3706
repair
Aug 11, 2020
72edd6a
debug
Aug 11, 2020
d128cd5
repair
Aug 11, 2020
795de43
repair
Aug 11, 2020
4c0550a
repair
Aug 11, 2020
a2c47b1
repair
Aug 11, 2020
ae201e8
repair
Aug 11, 2020
ce90830
repair
Aug 11, 2020
99fd9f6
repair
Aug 11, 2020
e5ff21f
repair
Aug 11, 2020
9e6b892
repair
Aug 11, 2020
ab7ebdd
repair
Aug 11, 2020
47712a0
repair
Aug 11, 2020
68a2db6
repair
Aug 11, 2020
8f8c0fd
repair
Aug 11, 2020
159b4c8
repair
Aug 11, 2020
ce4ad1e
repair
Aug 11, 2020
ce8a93c
repair
Aug 11, 2020
25767df
repair
Aug 11, 2020
418fc90
repair
Aug 11, 2020
0495da8
repair
Aug 11, 2020
5b267ff
repair
Aug 11, 2020
18e75ca
repair
Aug 11, 2020
6d56a78
repair
Aug 11, 2020
8622c43
repair
Aug 11, 2020
6dfec2c
repair
Aug 11, 2020
b35679c
repair
Aug 11, 2020
d0e6f3b
repair
Aug 11, 2020
13e9236
repair
Aug 11, 2020
f684550
repair
Aug 11, 2020
b5f8978
repair
Aug 11, 2020
f9a7353
simple
Aug 15, 2020
68ec750
mem
Aug 15, 2020
152 changes: 129 additions & 23 deletions pytorch_lightning/accelerators/ddp_backend.py
@@ -11,20 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import atexit
import os
import socket

import torch
import torch.distributed
import subprocess
import sys
from os.path import abspath
from time import sleep
from typing import Optional

import numpy as np
import torch
from os.path import abspath

from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_debug
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from typing import Optional

try:
from hydra.utils import to_absolute_path, get_original_cwd
@@ -37,14 +39,17 @@
try:
from apex import amp
except ImportError:
amp = None
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True


class DDPBackend(object):

def __init__(self, trainer):
self.trainer = trainer
self.task_idx = None
self.distributed_connection = DistributedConnection(trainer)

def slurm_setup(self):
self.task_idx = int(os.environ['SLURM_LOCALID'])
@@ -56,19 +61,15 @@ def train(self, model):
self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)

def spawn_ddp_children(self, model):
port = os.environ['MASTER_PORT']
assert self.trainer.global_rank == 0

master_address = '127.0.0.1' if 'MASTER_ADDR' not in os.environ else os.environ['MASTER_ADDR']
os.environ['MASTER_PORT'] = f'{port}'
master_address = os.environ.get('MASTER_ADDR', '127.0.0.1')
os.environ['MASTER_ADDR'] = f'{master_address}'

# allow the user to pass the node rank
node_rank = '0'
if 'NODE_RANK' in os.environ:
node_rank = os.environ['NODE_RANK']
if 'GROUP_RANK' in os.environ:
node_rank = os.environ['GROUP_RANK']

node_rank = os.environ.get('NODE_RANK', node_rank)
node_rank = os.environ.get('GROUP_RANK', node_rank)
os.environ['NODE_RANK'] = node_rank
os.environ['LOCAL_RANK'] = '0'

@@ -153,11 +154,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
# try to init for 20 times at max in case ports are taken
# where to store ip_table
model.trainer = self.trainer
model.init_ddp_connection(
self.trainer.global_rank,
self.trainer.world_size,
self.trainer.is_slurm_managing_tasks
)

self.distributed_connection.reset_connection(self.trainer, model)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)
@@ -176,6 +174,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies

print('here 1')

# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)
@@ -193,15 +193,20 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
gpu_idx = int(available_gpus[self.trainer.local_rank])

print('here 2')
self.trainer.root_gpu = gpu_idx
torch.cuda.set_device(self.trainer.root_gpu)
model.cuda(self.trainer.root_gpu)

print('here 3')

# set model properties before going into wrapper
self.trainer.copy_trainer_model_properties(model)

# AMP - run through amp wrapper before going to distributed DP
if self.trainer.amp_type == AMPType.APEX:
# AMP
# run through amp wrapper before going to distributed DP
# TODO: remove with dropping NVIDIA AMP support
if self.trainer.use_amp and not NATIVE_AMP_AVALAIBLE:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
@@ -212,12 +217,18 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
else: # includes ddp_cpu
device_ids = None

print('here 4')

# allow user to configure ddp
model = model.configure_ddp(model, device_ids)

print('here 5')

# continue training routine
results = self.trainer.run_pretrain_routine(model)

print('here 6')

# get original model
model = self.trainer.get_model()

@@ -229,3 +240,98 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0

if self.trainer.global_rank == 0 and self.trainer.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
return results


class DistributedConnection:

def __init__(self, trainer):
super().__init__()
self.trainer = trainer
if trainer.num_nodes == 1:
# select or forcibly set an initial port before ddp connection is initialized
self._set_master_port(port=self._get_master_port())

def reset_connection(self, trainer, model):
if not torch.distributed.is_initialized():
print('init ddp', 'rank', trainer.global_rank, 'port', self._get_master_port())
model.init_ddp_connection(trainer.global_rank, trainer.world_size, trainer.is_slurm_managing_tasks)

def reset_connection_old(self, trainer, model):

if not torch.distributed.is_initialized():
print('init ddp', 'rank', trainer.global_rank, 'port', self._get_master_port())
model.init_ddp_connection(trainer.global_rank, trainer.world_size, trainer.is_slurm_managing_tasks)
print('init ddp', 'rank', trainer.global_rank, 'port', self._get_master_port(), 'done')

new_port = torch.tensor([int(self._get_master_port())], dtype=torch.int, device='cuda')
if torch.distributed.is_initialized() and trainer.global_rank == 0:
print(trainer.global_rank, "DDP connection already initialized. Reinitializing on new port...")

#model.init_ddp_connection(trainer.global_rank, trainer.world_size, trainer.is_slurm_managing_tasks)

# torch.distributed.barrier()


#if trainer.global_rank == 0:
port = find_open_network_port()
new_port[0] = port

torch.distributed.broadcast(new_port, src=0)
new_port = int(new_port.item())
print('recv new port', 'rank', trainer.global_rank, 'port', new_port)

if int(self._get_master_port()) != new_port:
print('need to update port')
torch.distributed.destroy_process_group() # destroy connections on old port
print('destroy group', 'rank', trainer.global_rank, 'port', self._get_master_port())
print('set port', 'rank', trainer.global_rank, 'port', self._get_master_port())
self._set_master_port(port=new_port)

model.init_ddp_connection(trainer.global_rank, trainer.world_size, trainer.is_slurm_managing_tasks)

print('exit')

# s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# #print('shutdown', self._get_master_address(), int(self._get_master_port()))
# s.connect((self._get_master_address(), int(self._get_master_port())))
# s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# #s.shutdown(socket.SHUT_RDWR)
# s.close()
# #sleep(10)

def exit_handler():
if torch.distributed.is_initialized() and trainer.global_rank > 0:
print('destroying on ', trainer.global_rank)
torch.distributed.destroy_process_group()

atexit.register(exit_handler)

def _get_master_port(self):
return os.environ.get('MASTER_PORT')

def _get_master_address(self):
return os.environ.get('MASTER_ADDR')

def _set_master_port(self, port: int = None):
"""
Sets the `MASTER_PORT` environment variable in single-node DDP training.

Args:
port: If provided, sets the environment variable MASTER_PORT; otherwise
an attempt is made to find an unused open port.

Return:
The port that was set.
"""
assert self.trainer.num_nodes == 1, 'random port can only be called from single node training'
os.environ['MASTER_PORT'] = str(port or find_open_network_port())
return port


def find_open_network_port():
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
s.close()
return port
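
The reset_connection_old path above combines two small building blocks: the bind-to-port-0 trick in find_open_network_port, and a rank-0 broadcast so that every process agrees on the new port before reconnecting. A standalone sketch of that agreement step follows; the function name agree_on_new_port is mine, and it assumes an already-initialized process group with one CUDA device per rank (as in the NCCL case above).

import socket

import torch
import torch.distributed as dist


def find_open_network_port() -> int:
    """Ask the OS for a currently free TCP port by binding to port 0."""
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(("", 0))
    s.listen(1)
    port = s.getsockname()[1]
    s.close()
    return port


def agree_on_new_port(global_rank: int) -> int:
    """Rank 0 picks a free port and broadcasts it to all other ranks."""
    new_port = torch.tensor([0], dtype=torch.int, device='cuda')
    if global_rank == 0:
        new_port[0] = find_open_network_port()
    dist.broadcast(new_port, src=0)
    return int(new_port.item())

The usual caveat applies: a port found this way is only free at the moment of the check and can be taken by another process before it is reused, which is presumably why the "try to init for 20 times" comment remains in ddp_train.
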
24 changes: 12 additions & 12 deletions pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -15,25 +15,26 @@
import torch
import torch.multiprocessing as mp

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.accelerators.ddp_backend import DistributedConnection
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning import _logger as log

try:
from apex import amp
except ImportError:
amp = None
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True


class DDPSpawnBackend(object):

def __init__(self, trainer):
self.trainer = trainer
self.mp_queue = None
self.distributed_connection = DistributedConnection(trainer)

def setup(self):
self.trainer.set_random_port()

# pass in a state q
smp = mp.get_context('spawn')
self.mp_queue = smp.SimpleQueue()
@@ -94,11 +95,8 @@ def ddp_train(self, process_idx, mp_queue, model):
# try to init for 20 times at max in case ports are taken
# where to store ip_table
model.trainer = self.trainer
model.init_ddp_connection(
self.trainer.global_rank,
self.trainer.world_size,
self.trainer.is_slurm_managing_tasks
)

self.distributed_connection.reset_connection(self.trainer, model)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)
@@ -132,9 +130,11 @@ def ddp_train(self, process_idx, mp_queue, model):
# set model properties before going into wrapper
self.trainer.copy_trainer_model_properties(model)

# AMP -
# AMP
# run through amp wrapper before going to distributed DP
if self.trainer.amp_type == AMPType.APEX:
# TODO: remove with dropping NVIDIA AMP support
native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
if self.trainer.use_amp and not native_amp_available:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
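
The native_amp_available probe added in this hunk is a plain feature check on the installed torch build; a minimal standalone sketch, outside Lightning:

import torch

# torch.cuda.amp.autocast exists from PyTorch 1.6 onward; older builds fall
# back to the NVIDIA apex path (configure_apex above).
native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
print(f"native AMP available: {native_amp_available}")
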
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -954,7 +954,7 @@ def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managi
)

torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}")
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}, ADDR: {os.environ['MASTER_ADDR']}")
torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

def configure_sync_batchnorm(self, model: 'LightningModule') -> 'LightningModule':
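
For reference, the env:// handshake that init_ddp_connection performs boils down to setting MASTER_ADDR/MASTER_PORT and calling init_process_group. A minimal sketch with illustrative defaults (the address and port values are assumptions, not Lightning's):

import os

import torch.distributed as torch_distrib


def init_ddp_connection(global_rank: int, world_size: int, on_gpu: bool) -> None:
    """Sketch of the process-group initialization shown in the diff above."""
    # init_process_group's default env:// init method reads these variables.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '12910')
    torch_backend = 'nccl' if on_gpu else 'gloo'
    torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)
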
28 changes: 1 addition & 27 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -131,13 +131,7 @@ def train_fx(trial_hparams, cluster_manager, _):
import re
from abc import ABC, abstractmethod
from distutils.version import LooseVersion
from typing import Union, List, Optional, Callable, Tuple
import subprocess
import sys
from time import sleep
import numpy as np
from os.path import abspath
from pkg_resources import parse_version
from typing import Union, List, Optional, Tuple

import torch
from pytorch_lightning import _logger as log
@@ -168,10 +162,6 @@ def train_fx(trial_hparams, cluster_manager, _):
else:
XLA_AVAILABLE = True

PID = os.getpid()
RNG1 = np.random.RandomState(PID)
RANDOM_PORTS = RNG1.randint(10000, 19999, 1000)


class TrainerDDPMixin(ABC):

@@ -397,22 +387,6 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
# don't make this debug... this is good UX
rank_zero_info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')

def set_random_port(self, force=False):
"""
When running DDP NOT managed by SLURM, the ports might collide
"""
# pick a random port first
assert self.num_nodes == 1, 'random port can only be called from single node training'
global RANDOM_PORTS
default_port = RANDOM_PORTS[-1]
RANDOM_PORTS = RANDOM_PORTS[:-1]

# when not forced, use the user port
if not force:
default_port = os.environ.get('MASTER_PORT', default_port)

os.environ['MASTER_PORT'] = str(default_port)

def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
if self.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
return