Support learning rate schedulers in ExpertBackend (#196)
* Add empty __init__ to hivemind_cli for correct package discovery

* Support learning rate schedulers in ExpertBackend

* Save/load full expert state

* Don't pass compression to make_empty

* spawn -> fork

* Remove load_expert_states

* Make TaskPoolBase an abstract class

* Output warning if some of the keys in state_dict are missing

Co-authored-by: justheuristic <[email protected]>
mryab and justheuristic authored Apr 2, 2021
1 parent f132294 commit 3024d38
Showing 21 changed files with 384 additions and 227 deletions.
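For orientation before the diffs: schedule_name_to_scheduler, referenced throughout this commit, maps a scheduler name to a factory that builds a torch LR scheduler from an optimizer and the two step counts. A minimal sketch of what such a registry might contain, assuming a linear warmup-then-decay shape for 'linear' (the exact schedule is an assumption consistent with the --num-warmup-steps/--num-training-steps flags added below, not something this diff shows):

    from torch.optim.lr_scheduler import LambdaLR

    def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
        # ramp the LR up linearly for num_warmup_steps, then decay it linearly to zero
        def lr_lambda(current_step: int) -> float:
            if current_step < num_warmup_steps:
                return current_step / max(1, num_warmup_steps)
            return max(0.0, (num_training_steps - current_step)
                       / max(1, num_training_steps - num_warmup_steps))
        return LambdaLR(optimizer, lr_lambda)

    # 'none' maps to None so the backend can skip scheduling altogether
    schedule_name_to_scheduler = {'linear': get_linear_schedule_with_warmup, 'none': None}

Mapping 'none' to None keeps the no-scheduler path a simple check against None rather than a dummy scheduler object.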
12 changes: 6 additions & 6 deletions .circleci/config.yml
@@ -8,11 +8,11 @@ jobs:
       - checkout
       - restore_cache:
           keys:
-            - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+            - py37-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:
-          key: v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+          key: py37-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
           paths:
             - '~/.cache/pip'
       - run:
@@ -28,11 +28,11 @@ jobs:
       - checkout
       - restore_cache:
           keys:
-            - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+            - py38-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:
-          key: v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+          key: py38-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
           paths:
             - '~/.cache/pip'
       - run:
@@ -48,11 +48,11 @@ jobs:
       - checkout
       - restore_cache:
           keys:
-            - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+            - py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:
-          key: v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+          key: py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
           paths:
             - '~/.cache/pip'
       - run:
0 changes: 0 additions & 0 deletions hivemind/hivemind_cli/__init__.py
Empty file.
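This empty __init__.py is the "correct package discovery" fix from the commit message; a sketch of why it matters, assuming the project packages itself with setuptools (the setup.py excerpt below is hypothetical, not part of this diff):

    # setup.py (hypothetical excerpt): find_packages() only collects directories
    # that contain an __init__.py, so without this empty file the hivemind_cli
    # package (and run_server.py inside it) would be left out of installs.
    from setuptools import find_packages

    print(find_packages())  # expect e.g. ['hivemind', 'hivemind.hivemind_cli', ...]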
9 changes: 8 additions & 1 deletion hivemind/hivemind_cli/run_server.py
@@ -8,6 +8,7 @@
 from hivemind.server import Server
 from hivemind.utils.threading import increase_file_limit
 from hivemind.utils.logging import get_logger
+from hivemind.server.layers import schedule_name_to_scheduler
 
 logger = get_logger(__name__)
 
@@ -28,13 +29,20 @@ def main():
     parser.add_argument('--expert_cls', type=str, default='ffn', required=False,
                         help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.")
     parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
+
     parser.add_argument('--num_handlers', type=int, default=None, required=False,
                         help='server will use this many processes to handle incoming requests')
     parser.add_argument('--max_batch_size', type=int, default=16384, required=False,
                         help='The total number of examples in the same batch will not exceed this value')
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all experts will use this device in torch notation; default: cuda if available else cpu')
+
     parser.add_argument('--optimizer', type=str, default='adam', required=False, help='adam, sgd or none')
+    parser.add_argument('--scheduler', type=str, choices=schedule_name_to_scheduler.keys(), default='none',
+                        help='LR scheduler type to use')
+    parser.add_argument('--num-warmup-steps', type=int, required=False, help='the number of warmup steps for LR schedule')
+    parser.add_argument('--num-training-steps', type=int, required=False, help='the total number of steps for LR schedule')
+
     parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
                         help='one or more peers that can welcome you to the dht, e.g. 1.2.3.4:1337 192.132.231.4:4321')
@@ -45,7 +53,6 @@ def main():
     parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression '
                         'parameter for grpc. Can be NONE, MEANSTD or FLOAT16')
     parser.add_argument('--checkpoint_dir', type=Path, required=False, help='Directory to store expert checkpoints')
-    parser.add_argument('--load_experts', action='store_true', help='Load experts from the checkpoint directory')
 
     # fmt:on
     args = vars(parser.parse_args())
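These flags feed straight into Server.create (next file). A hedged sketch of the equivalent programmatic call, with illustrative argument values:

    import torch
    from hivemind.server import Server

    # roughly equivalent to: --num_experts 4 --optimizer adam --scheduler linear
    #                        --num-warmup-steps 1000 --num-training-steps 100000
    server = Server.create(num_experts=4, expert_cls='ffn', hidden_dim=1024,
                           optim_cls=torch.optim.Adam, scheduler='linear',
                           num_warmup_steps=1000, num_training_steps=100000,
                           start=True)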
142 changes: 48 additions & 94 deletions hivemind/server/__init__.py
@@ -2,23 +2,22 @@
 
 import multiprocessing as mp
 import multiprocessing.synchronize
-import random
 import threading
 from contextlib import contextmanager
 from functools import partial
-from typing import Dict, Optional, Tuple, List
+from typing import Dict, Optional, Tuple
 from pathlib import Path
 
 import torch
 
 import hivemind
 from hivemind.dht import DHT
-from hivemind.server.expert_uid import UID_DELIMITER
-from hivemind.server.checkpoints import CheckpointSaver, load_weights, dir_is_correct
+from hivemind.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
+from hivemind.server.checkpoints import CheckpointSaver, load_experts, is_directory
 from hivemind.server.connection_handler import ConnectionHandler
 from hivemind.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
 from hivemind.server.expert_backend import ExpertBackend
-from hivemind.server.layers import name_to_block, name_to_input
+from hivemind.server.layers import name_to_block, name_to_input, schedule_name_to_scheduler
 from hivemind.server.runtime import Runtime
 from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
 from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
@@ -68,11 +67,12 @@ def __init__(
         if start:
             self.run_in_background(await_ready=True)
 
-    @staticmethod
-    def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
-               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
-               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
-               load_experts=False, compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
+    @classmethod
+    def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
+               expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
+               num_warmup_steps=None, num_training_steps=None, num_handlers=None, max_batch_size=4096, device=None,
+               no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
+               compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -85,16 +85,20 @@ def create(listen_on='0.0.0.0:*', num_experts: int = No
         :param num_handlers: server will use this many parallel processes to handle incoming requests
         :param max_batch_size: total num examples in the same batch will not exceed this value
         :param device: all experts will use this device in torch notation; default: cuda if available else cpu
         :param optim_cls: uses this optimizer to train all experts
+        :param scheduler: if not `none`, the name of the expert LR scheduler
+        :param num_warmup_steps: the number of warmup steps for LR schedule
+        :param num_training_steps: the total number of steps for LR schedule
         :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: a list of peers that will introduce this node to the dht,\
          e.g. ('123.11.22.33:1337', '[fe80::abe2:db1c:be7d:5a85]:4567'), default = no peers
         :param dht_port: DHT node will listen on this port, default = find open port
          You can then use this node as initial peer for subsequent servers.
-        :param checkpoint_dir: directory to save expert checkpoints
-        :param load_experts: whether to load expert checkpoints from checkpoint_dir
+        :param checkpoint_dir: directory to save and load expert checkpoints
         :param compression: if specified, use this compression to pack all inputs, outputs and gradients by all experts
          hosted on this server. For a more fine-grained compression, start server in python and specify compression
@@ -113,23 +117,29 @@ def create(listen_on='0.0.0.0:*', num_experts: int = No
         dht = hivemind.DHT(initial_peers=initial_peers, start=True, listen_on=dht_endpoint)
         logger.info(f"Running DHT node on port {dht.port}, initial peers = {initial_peers}")
 
-        if load_experts:
-            assert dir_is_correct(checkpoint_dir)
-            assert expert_uids is None, "Can't both load saved experts and create new ones from given UIDs"
-            expert_uids = [child.name for child in checkpoint_dir.iterdir() if (child / 'checkpoint_last.pt').exists()]
-            if expert_uids:
-                logger.info(f"Located checkpoints for experts {expert_uids}, ignoring UID generation options")
-            else:
-                logger.info(f"No expert checkpoints found in {checkpoint_dir}, generating...")
 
-        assert (expert_pattern is None and num_experts is None) or (expert_uids is None) or (num_experts == 0), \
-            "Please provide either expert_uids *or* num_experts and expert_pattern, but not both"
+        assert ((expert_pattern is None and num_experts is None and expert_uids is not None) or
+                (num_experts is not None and expert_uids is None)), \
+            "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
 
-        # get expert uids if not loaded previously
         if expert_uids is None:
-            assert num_experts is not None, "Please specify either expert_uids or num_experts [and expert_pattern]"
-            logger.info(f"Generating expert uids from pattern {expert_pattern}")
-            expert_uids = generate_uids_from_pattern(num_experts, expert_pattern, dht=dht)
+            if checkpoint_dir is not None:
+                assert is_directory(checkpoint_dir)
+                expert_uids = [child.name for child in checkpoint_dir.iterdir() if
+                               (child / 'checkpoint_last.pt').exists()]
+                total_experts_in_checkpoint = len(expert_uids)
+                logger.info(f"Located {total_experts_in_checkpoint} checkpoints for experts {expert_uids}")
+
+                if total_experts_in_checkpoint > num_experts:
+                    raise ValueError(
+                        f"Found {total_experts_in_checkpoint} checkpoints, but num_experts is set to {num_experts}, "
+                        f"which is smaller. Either increase num_experts or remove unneeded checkpoints.")
+            else:
+                expert_uids = []
+
+            uids_to_generate = num_experts - len(expert_uids)
+            if uids_to_generate > 0:
+                logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
+                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))
 
         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
@@ -142,6 +152,8 @@ def create(listen_on='0.0.0.0:*', num_experts: int = No
         else:
             args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)
 
+        scheduler = schedule_name_to_scheduler[scheduler]
+
         # initialize experts
         experts = {}
         for expert_uid in expert_uids:
@@ -150,15 +162,17 @@ def create(listen_on='0.0.0.0:*', num_experts: int = No
                                                      args_schema=args_schema,
                                                      outputs_schema=hivemind.BatchTensorDescriptor(
                                                          hidden_dim, compression=compression),
-                                                     opt=optim_cls(expert.parameters()),
+                                                     optimizer=optim_cls(expert.parameters()),
+                                                     scheduler=scheduler,
+                                                     num_warmup_steps=num_warmup_steps,
+                                                     num_training_steps=num_training_steps,
                                                      max_batch_size=max_batch_size)
 
-        if load_experts:
-            load_weights(experts, checkpoint_dir)
+        if checkpoint_dir is not None:
+            load_experts(experts, checkpoint_dir)
 
-        server = Server(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
-                        start=start)
-        return server
+        return cls(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
+                   checkpoint_dir=checkpoint_dir, start=start)
 
     def run(self):
         """
@@ -241,7 +255,7 @@ def shutdown(self):
 def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     pipe, runners_pipe = mp.Pipe(duplex=True)
-    runner = mp.get_context("spawn").Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
+    runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
 
     try:
         runner.start()
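For reference, a hedged usage sketch of this context manager (argument values are illustrative; per the return annotation above, it yields the server and DHT endpoints):

    # spins up a server in a background process, then tears it down on exit
    with background_server(num_experts=1, hidden_dim=64, device='cpu') as (server_endpoint, dht_endpoint):
        pass  # point clients (e.g. remote expert lookups) at server_endpoint here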
@@ -269,63 +283,3 @@ def _server_runner(pipe, *args, **kwargs):
         server.shutdown()
         server.join()
         logger.info("Server shut down.")
-
-
-def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None,
-                               attempts_per_expert=10) -> List[str]:
-    """
-    Sample experts from a given pattern, remove duplicates.
-    :param num_experts: sample this many unique expert uids
-    :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
-     means "sample random experts between myprefix.0.0 and myprefix.31.255"
-    :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
-    :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
-    :note: this method is not strictly process-safe. If several servers run it concurrently, they have
-     a small chance of sampling duplicate expert uids.
-    """
-    remaining_attempts = attempts_per_expert * num_experts
-    found_uids, attempted_uids = list(), set()
-
-    def _generate_uid():
-        if expert_pattern is None:
-            return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
-
-        uid = []
-        for block in expert_pattern.split(UID_DELIMITER):
-            try:
-                if '[' not in block and ']' not in block:
-                    uid.append(block)
-                elif block.startswith('[') and block.endswith(']') and ':' in block:
-                    slice_start, slice_end = map(int, block[1:-1].split(':'))
-                    uid.append(str(random.randint(slice_start, slice_end - 1)))
-                else:
-                    raise ValueError("Block must be either fixed or a range [from:to]")
-            except KeyboardInterrupt as e:
-                raise e
-            except Exception as e:
-                raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
-        return UID_DELIMITER.join(uid)
-
-    while remaining_attempts > 0 and len(found_uids) < num_experts:
-
-        # 1. sample new expert uids at random
-        new_uids = []
-        while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
-            new_uid = _generate_uid()
-            remaining_attempts -= 1
-            if new_uid not in attempted_uids:
-                attempted_uids.add(new_uid)
-                new_uids.append(new_uid)
-
-        # 2. look into DHT (if given) and remove duplicates
-        if dht:
-            existing_expert_uids = {found_expert.uid for found_expert in dht.get_experts(new_uids)
-                                    if found_expert is not None}
-            new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
-
-        found_uids += new_uids
-
-    if len(found_uids) != num_experts:
-        logger.warning(f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
-                       f"{attempts_per_expert * num_experts} attempts")
-    return found_uids
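A worked example of the pattern grammar accepted by this helper, which now lives in hivemind.server.expert_uid per the import change above (the prefix and range sizes below are illustrative):

    # each bracketed block is sampled via random.randint(start, end - 1):
    # [0:4] yields 0..3 and [0:256] yields 0..255, so this call can produce
    # e.g. ['ffn.2.17', 'ffn.0.255', 'ffn.3.40', 'ffn.1.101']
    uids = generate_uids_from_pattern(num_experts=4, expert_pattern='ffn.[0:4].[0:256]')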
23 changes: 15 additions & 8 deletions hivemind/server/checkpoints.py
@@ -1,17 +1,20 @@
+import os
 import threading
 from datetime import datetime
 from pathlib import Path
 from shutil import copy2
 from tempfile import TemporaryDirectory
 from typing import Dict
-import os
 
 import torch
 
 from hivemind.server.expert_backend import ExpertBackend
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
 
 
-def dir_is_correct(directory: Path):
+def is_directory(directory: Path):
     assert directory is not None
     assert directory.exists()
     assert directory.is_dir()
@@ -33,7 +36,7 @@ def copy_tree(src: str, dst: str):
 class CheckpointSaver(threading.Thread):
     def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: int):
         super().__init__()
-        assert dir_is_correct(checkpoint_dir)
+        assert is_directory(checkpoint_dir)
         self.expert_backends = expert_backends
         self.update_period = update_period
         self.checkpoint_dir = checkpoint_dir
@@ -48,21 +51,25 @@ def run(self) -> None:
 
 
 def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
-    assert dir_is_correct(checkpoint_dir)
+    logger.debug(f'Storing experts at {checkpoint_dir.absolute()}')
+    assert is_directory(checkpoint_dir)
     timestamp = datetime.now().isoformat(sep='_')
     with TemporaryDirectory() as tmpdirname:
         for expert_name, expert_backend in experts.items():
             expert_dir = Path(tmpdirname) / expert_name
             expert_dir.mkdir()
             checkpoint_name = expert_dir / f'checkpoint_{timestamp}.pt'
-            torch.save(expert_backend.state_dict(), checkpoint_name)
+            torch.save(expert_backend.get_full_state(), checkpoint_name)
             os.symlink(checkpoint_name, expert_dir / 'checkpoint_last.pt')
         copy_tree(tmpdirname, str(checkpoint_dir))
 
 
-def load_weights(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
-    assert dir_is_correct(checkpoint_dir)
+def load_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
+    assert is_directory(checkpoint_dir)
     for expert_name, expert in experts.items():
         checkpoints_folder = checkpoint_dir / expert_name
         latest_checkpoint = checkpoints_folder / 'checkpoint_last.pt'
-        expert.load_state_dict(torch.load(latest_checkpoint))
+        if latest_checkpoint.exists():
+            expert.load_full_state(torch.load(latest_checkpoint))
+        else:
+            logger.warning(f'Failed to load checkpoint for expert {expert_name}')
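Because checkpoints now store get_full_state() rather than bare weights, optimizer and scheduler state survive a restart ("Save/load full expert state" in the commit message). A round-trip sketch with an illustrative directory name:

    from pathlib import Path

    checkpoint_dir = Path('expert_checkpoints')  # illustrative; must already exist
    store_experts(experts, checkpoint_dir)  # writes <uid>/checkpoint_<timestamp>.pt per expert,
                                            # plus a checkpoint_last.pt symlink
    load_experts(experts, checkpoint_dir)   # restores full state; warns instead of
                                            # crashing when an expert has no checkpoint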
2 changes: 1 addition & 1 deletion hivemind/server/connection_handler.py
@@ -16,7 +16,7 @@
 logger = get_logger(__name__)
 
 
-class ConnectionHandler(mp.Process):
+class ConnectionHandler(mp.context.ForkProcess):
     """
     A process that accepts incoming requests to experts and submits them into the corresponding TaskPool.
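Deriving from the fork-specific process class pins the start method, matching the "spawn -> fork" item in the commit message, regardless of the global multiprocessing default. A minimal illustration (fork is POSIX-only, so this sketch assumes a non-Windows host):

    import multiprocessing as mp

    class Handler(mp.context.ForkProcess):
        # forks even if mp.set_start_method('spawn') was called elsewhere,
        # so the child inherits the parent's memory copy-on-write
        def run(self):
            print('running in a forked child')

    if __name__ == '__main__':
        handler = Handler()
        handler.start()
        handler.join()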
(diffs for the remaining 15 changed files were not loaded)
