From 4a006ae7cffdab0d031a6b45d822208e38977b17 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Fri, 16 Jul 2021 00:35:54 +0300 Subject: [PATCH] Update quickstart tutorials and acknowledgements (#307) - moved MoE from quickstart to a separate tutorial - added basic decentralized training example (new quickstart) - added basic dht tutorial - updated acknowledgements - added publications in README - replaced personal e-mail with a shared one - moved from gitter to discord - updated a more detailed installation guide, including WSL Co-authored-by: Michael Diskin Co-authored-by: Aleksandr Borzunov Co-authored-by: Max Ryabinin Co-authored-by: Alexander Borzunov Co-authored-by: Michael Diskin --- README.md | 110 +++++++++--- docs/index.rst | 2 + docs/user/acknowledgements.md | 20 +-- docs/user/dht.md | 135 ++++++++++++++ docs/user/moe.md | 184 +++++++++++++++++++ docs/user/quickstart.md | 302 +++++++++++++++----------------- hivemind/averaging/training.py | 8 +- hivemind/moe/__init__.py | 2 +- hivemind/moe/server/__init__.py | 3 +- hivemind/optim/simple.py | 19 +- setup.py | 6 +- tests/test_training.py | 25 +-- 12 files changed, 594 insertions(+), 222 deletions(-) create mode 100644 docs/user/dht.md create mode 100644 docs/user/moe.md diff --git a/README.md b/README.md index 035ffc612..5c3e8b817 100644 --- a/README.md +++ b/README.md @@ -1,36 +1,38 @@ ## Hivemind: decentralized deep learning in PyTorch -[![CI status](https://github.com/learning-at-home/hivemind/actions/workflows/run-tests.yml/badge.svg?branch=master)](https://github.com/learning-at-home/hivemind/actions) [![Documentation Status](https://readthedocs.org/projects/learning-at-home/badge/?version=latest)](https://learning-at-home.readthedocs.io/en/latest/?badge=latest) -[![Gitter](https://badges.gitter.im/learning-at-home/hivemind.svg)](https://gitter.im/learning-at-home/hivemind?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) +[![PyPI version](https://img.shields.io/pypi/v/hivemind.svg)](https://pypi.org/project/hivemind/) +[![Discord](https://img.shields.io/static/v1?style=default&label=Discord&logo=discord&message=join)](https://discord.gg/xC7ucM8j) +[![CI status](https://github.com/learning-at-home/hivemind/actions/workflows/run-tests.yml/badge.svg?branch=master)](https://github.com/learning-at-home/hivemind/actions) +![Codecov](https://img.shields.io/codecov/c/github/learning-at-home/hivemind) [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -Hivemind is a PyTorch library to train large neural networks across the Internet. Its intended usage is training a -single Transformer model on hundreds of computers from different universities, companies, and volunteers. +Hivemind is a PyTorch library for decentralized deep learning across the Internet. Its intended usage is training one +large model on hundreds of computers from different universities, companies, and volunteers. ![img](https://i.imgur.com/GPxolxb.gif) ## Key Features -* Train neural networks of arbitrary size: parts of their layers are distributed across the participants. * Distributed training without a master node: Distributed Hash Table allows connecting computers in a decentralized network. * Fault-tolerant backpropagation: forward and backward passes succeed even if some nodes are unresponsive or take too long to respond. -* Decentralized parameter averaging: iteratively aggregate updates from multiple workers without the need to - synchronize across the entire network. +* Decentralized parameter averaging: iteratively aggregate updates from multiple + workers without the need to synchronize across the entire network ([paper](https://arxiv.org/abs/2103.03239)). +* Train neural networks of arbitrary size: parts of their layers are distributed across the participants with the + decentralized mixture-of-experts ([paper](https://arxiv.org/abs/2002.04013)). To learn more about the ideas behind this library, see https://learning-at-home.github.io or read the [NeurIPS 2020 paper](https://arxiv.org/abs/2002.04013). ## Installation -Before installing hivemind, make sure that your environment has Python 3.7+ -and [PyTorch](https://pytorch.org/get-started/locally/#start-locally) with a version at least as new as 1.6.0. +Before installing, make sure that your environment has Python 3.7+ +and [PyTorch](https://pytorch.org/get-started/locally/#start-locally) 1.6.0 or newer. +You can install them either natively or with [Anaconda](https://www.anaconda.com/products/individual). -To start using this library, you can either use the pip package manager or build it from source. Since currently the -release cycle is not established yet, we recommend installing hivemind from source to keep up with the latest bugfixes -and improvements. +You can install [the latest release](https://pypi.org/project/hivemind) with pip or build hivemind from source. ### With pip @@ -42,7 +44,7 @@ pip install hivemind ### From source -To install hivemind from source, simply clone the repository and install +To install hivemind from source, simply run the following: ``` git clone https://github.com/learning-at-home/hivemind.git @@ -53,11 +55,31 @@ pip install . If you would like to verify that your installation is working properly, you can install with `pip install -e .[dev]` instead. Then, you can run the tests with `pytest tests/`. +By default, hivemind uses the precompiled binary of +the [go-libp2p-daemon](https://github.com/learning-at-home/go-libp2p-daemon) library. If you face compatibility issues +or want to build the binary yourself, you can recompile it by running `pip install . --global-option="--buildgo"`. +Before running the compilation, please ensure that your machine has a recent version +of [Go toolchain](https://golang.org/doc/install) (1.15 or higher). + +### System requirements +- __Linux__ is the default OS for which hivemind is developed and tested. We recommend Ubuntu 18.04+ (64-bit), + but other 64-bit distros should work as well. Legacy 32-bit is not recommended. +- __macOS 10.x__ mostly works but requires building hivemind from source, and some edge cases may fail. + To ensure full compatibility, we recommend using [our Docker image](https://hub.docker.com/r/learningathome/hivemind). +- __Windows 10+ (experimental)__ can run hivemind using [WSL](https://docs.microsoft.com/ru-ru/windows/wsl/install-win10). + You can configure WSL to use GPU following [this guide](https://docs.nvidia.com/cuda/wsl-user-guide/index.html) by NVIDIA. + After the CUDA toolkit is installed you can simply follow the instructions above to install with pip or from source. + ## Documentation -* [Quickstart](https://learning-at-home.readthedocs.io/en/latest/user/quickstart.html): install hivemind, set up a - server and train experts -* Documentation & guides are available at [learning-at-home.readthedocs.io](https://learning-at-home.readthedocs.io) +* The [quickstart tutorial](https://learning-at-home.readthedocs.io/en/latest/user/quickstart.html) walks through installation + and a training a simple neural network with several peers. +* [examples/albert](https://github.com/learning-at-home/hivemind/tree/master/examples/albert) contains the starter kit + and instructions for training a Transformer masked language model collaboratively. +* API reference and additional tutorials are available at [learning-at-home.readthedocs.io](https://learning-at-home.readthedocs.io) + +If you have any questions about installing and using hivemind, you can ask them in +[our Discord chat](https://discord.gg/xC7ucM8j) or file an [issue](https://github.com/learning-at-home/hivemind/issues). ## Contributing @@ -66,7 +88,7 @@ documentation improvements to entirely new features, is equally appreciated. If you want to contribute to hivemind but don't know where to start, take a look at the unresolved [issues](https://github.com/learning-at-home/hivemind/issues). Open a new issue or -join [our chat room](https://gitter.im/learning-at-home/hivemind) in case you want to discuss new functionality or +join [our chat room](https://discord.gg/xC7ucM8j) in case you want to discuss new functionality or report a possible bug. Bug fixes are always welcome, but new features should be preferably discussed with maintainers beforehand. @@ -77,7 +99,7 @@ our [guide](https://learning-at-home.readthedocs.io/en/latest/user/contributing. ## Citation -If you found hivemind or its underlying algorithms useful for your experiments, please cite the following source: +If you found hivemind or its underlying algorithms useful for your research, please cite the relevant papers: ``` @misc{hivemind, @@ -88,7 +110,8 @@ If you found hivemind or its underlying algorithms useful for your experiments, } ``` -Also, you can cite [the paper](https://arxiv.org/abs/2002.04013) that inspired the creation of this library: +Also, you can cite [the paper](https://arxiv.org/abs/2002.04013) that inspired the creation of this library +(prototype implementation of hivemind available at [mryab/learning-at-home](https://github.com/mryab/learning-at-home)): ``` @inproceedings{ryabinin2020crowdsourced, @@ -104,10 +127,49 @@ Also, you can cite [the paper](https://arxiv.org/abs/2002.04013) that inspired t } ``` -The initial implementation of hivemind used for the paper is available -at [mryab/learning-at-home](https://github.com/mryab/learning-at-home). +
+ Additional publications + +["Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices"](https://arxiv.org/abs/2103.03239) + +``` +@misc{ryabinin2021moshpit, + title={Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices}, + author={Max Ryabinin and Eduard Gorbunov and Vsevolod Plokhotnyuk and Gennady Pekhimenko}, + year={2021}, + eprint={2103.03239}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` + +["Distributed Deep Learning in Open Collaborations"](https://arxiv.org/abs/2106.10207) + +``` +@misc{diskin2021distributed, + title={Distributed Deep Learning in Open Collaborations}, + author={Michael Diskin and Alexey Bukhtiyarov and Max Ryabinin and Lucile Saulnier and Quentin Lhoest and Anton Sinitsin and Dmitry Popov and Dmitry Pyrkin and Maxim Kashirin and Alexander Borzunov and Albert Villanova del Moral and Denis Mazur and Ilia Kobelev and Yacine Jernite and Thomas Wolf and Gennady Pekhimenko}, + year={2021}, + eprint={2106.10207}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` + +["Secure Distributed Training at Scale"](https://arxiv.org/abs/2106.11257) + +``` +@misc{gorbunov2021secure, + title={Secure Distributed Training at Scale}, + author={Eduard Gorbunov and Alexander Borzunov and Michael Diskin and Max Ryabinin}, + year={2021}, + eprint={2106.11257}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` -In the documentation, we list -several [related](https://learning-at-home.readthedocs.io/en/latest/user/acknowledgements.html) projects and -acknowledgements. +
+We also maintain a list of [related projects and +acknowledgements](https://learning-at-home.readthedocs.io/en/latest/user/acknowledgements.html). diff --git a/docs/index.rst b/docs/index.rst index e4b04503c..31fba5c4d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -21,6 +21,8 @@ documentation below. user/quickstart modules/index + user/dht + user/moe user/contributing user/benchmarks user/acknowledgements diff --git a/docs/user/acknowledgements.md b/docs/user/acknowledgements.md index 770d5eeec..5546ce6f2 100644 --- a/docs/user/acknowledgements.md +++ b/docs/user/acknowledgements.md @@ -1,6 +1,6 @@ -# Credits +# Acknowledgements -We kindly thank (in random order) +We kindly thank (in no particular order) * [Artem Babenko](https://research.yandex.com/people/102794) and [Vladimir Aliev](https://ru.linkedin.com/in/vladimir-aliev-19b93282) for helpful discussions and editorial review of @@ -14,15 +14,15 @@ We kindly thank (in random order) * [Brian Muller](https://github.com/bmuller/kademlia) for his implementations of [kademlia](https://github.com/bmuller/kademlia) and [rpcudp](https://github.com/bmuller/rpcudp) * Alexander Sherbakov for helpful discussions on PC and server component architecture, -* Our early adopters, [contributors](https://github.com/learning-at-home/hivemind/graphs/contributors), and reviewers +* [Yandex School of Data Analysis](https://yandexdataschool.com) students, for helping us run first truly collaborative experiments. +* The [Neuropark community](https://neuropark.co/), for hosting early collaborative training experiments of sahajBERT with hivemind. +* Our early adopters, [contributors](https://github.com/learning-at-home/hivemind/graphs/contributors), and conference reviewers. # Related projects -We also want to reference several projects that have similar ideas in mind: - +In this section, we list several organizations and research projects that bring humanity closer to the dream of world-scale deep learning with volunteer computing. +* [Hugging Face](https://huggingface.co) — an AI community with world-leading NLP research that builds collaborative hub training using hivemind. +* [EYDLE](https://www.eydle.com) — a start-up that works towards distributed deep learning on volunteer hardware using centralized infrastructure. * [BitTensor](https://github.com/opentensor/BitTensor) — a decentralized deep learning ecosystem with incentive - mechanism. Like hivemind, but peers are getting rewarded for their contribution to other peers. . -* [GShard](https://arxiv.org/abs/2006.16668) — a paper by Dmitry Lepikhin et al. that demonstrate the effectiveness of - huge Mixture-of-Experts models on conventional hpc hardware. Those guys train models 4 times the size of GPT-3 on - thousands of TPUv3. -* Also doing research in decentralized deep learning? Let us know! \ No newline at end of file + mechanism. Each peer trains for its own objective and rewards others for useful features. +* Also building collaborative deep learning? Let us know! `hivemind-team hotmail.com` diff --git a/docs/user/dht.md b/docs/user/dht.md new file mode 100644 index 000000000..fc6199975 --- /dev/null +++ b/docs/user/dht.md @@ -0,0 +1,135 @@ +# Hivemind DHT + +In order to coordinate, hivemind peers form a Distributed Hash Table: distributed "dictionary" where each peer +can store and get values. To initialize the first DHT node, run + +```python +from hivemind import DHT, get_dht_time + +dht = DHT(start=True) +# create the first DHT node that listens for incoming connections from localhost only + +print("For incoming connections, use:", dht.get_visible_maddrs()) +``` + +You can now start more peers that connect to an existing DHT node using its listen address: +```python +dht2 = DHT(initial_peers=dht.get_visible_maddrs(), start=True) +``` + +Note that `initial_peers` contains the address of the first DHT node. +This implies that the resulting node will have shared key-value with the first node, __as well as any other +nodes connected to it.__ When the two nodes are connected, subsequent peers can use any one of them (or both) +as `initial_peers` to connect to the shared "dictionary". + +### Store/get operations + +Once the DHT is formed, all participants can `dht.store` key-value pairs in the DHT and `dht.get` them by key: + +```python +# first node: store a key-value pair for 600 seconds +store_ok = dht.store('my_key', ('i', 'love', 'bees'), + expiration_time=get_dht_time() + 600) + +# second node: get the value stored by the first node +value, expiration = dht2.get('my_key', latest=True) +assert value == ('i', 'love', 'bees') +``` + +As you can see, each value in a hivemind DHT is associated with an expiration time, +computed current `get_dht_time()` with some offset. +This expiration time is used to cleanup old data and resolve write conflicts: +DHT nodes always prefer values with higher expiration time and may delete any value past its expiration. + +### Values with subkeys + +Hivemind DHT also supports a special value type that is itself a dictionary. When nodes store such a value, +they add sub-keys to the dictionary instead of overwriting it. + +Consider an example where three DHT nodes want to find out who is going to attend the party: + +```python +alice_dht = DHT(initial_peers=dht.get_visible_maddrs(), start=True) +bob_dht = DHT(initial_peers=dht2.get_visible_maddrs(), start=True) +carol_dht = DHT(initial_peers=alice_dht.get_visible_maddrs(), start=True) + + +# first, each peer stores a subkey for the same key +alice_dht.store('party', subkey='alice', value='yes', expiration_time=get_dht_time() + 600) +bob_dht.store('party', subkey='bob', value='yes', expiration_time=get_dht_time() + 600) +carol_dht.store('party', subkey='carol', value='no', expiration_time=get_dht_time() + 600) + +# then, any peer can get the full list of attendees +attendees, expiration = alice_dht.get('party', latest=True) +print(attendees) +# {'alice': ValueWithExpiration(value='yes', expiration_time=1625504352.2668974), +# 'bob': ValueWithExpiration(value='yes', expiration_time=1625504352.2884178), +# 'carol': ValueWithExpiration(value='no', expiration_time=1625504352.3046832)} +``` + +When training over the Internet, some `dht.get/store` requests may run for hundreds of milliseconds and even seconds. +To minimize the wait time, you can call these requests asynchronously via +[`dht.store/get/run_coroutine(..., return_future=True)`__](https://learning-at-home.readthedocs.io/en/latest/modules/dht.html#hivemind.dht.DHT.get) +. This will run the corresponding command in background and return a [Future-like](https://docs.python.org/3/library/concurrent.futures.html) object that can be awaited. +Please also note that the returned future is compatible with asyncio (i.e. can be awaited inside the event loop). + +For more details on DHT store/get and expiration time, please refer to the [documentation for DHT and DHTNode](https://learning-at-home.readthedocs.io/en/latest/modules/dht.html#dht-and-dhtnode) + +### Running across the Internet + +By default, DHT nodes are only accessible from your localhost. In order to run with multiple geographically +distributed computers, one must connect DHT to a global network. Currently, there are two ways achieve this. + +The recommended approach is to grow the network from one or several initial peers. These can be any computers with a +public IP address that are always online. Each of these peers should simply create `hivemind.DHT` and set it to +accept incoming connections from the internet: + +```python +import hivemind +dht = hivemind.DHT( + host_maddrs=["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"], + start=True) + +print('\n'.join(str(addr) for addr in dht.get_visible_maddrs())) +print("Global IP:", hivemind.utils.networking.choose_ip_address(dht.get_visible_maddrs())) +``` + +Running this code will print several, typically, 4 or 6 strings of the following form (example): +```shell +/ip4/185.185.123.124/tcp/40615/p2p/QmaVTB2LwayToK2rzMkaCbkCaH7nF2rTHIS0IS0AN0EXAMPLE +/ip4/127.0.0.1/tcp/40615/p2p/QmaVTB2LwayToK2rzMkaCbkCaH7nF2rTHIS0IS0AN0EXAMPLE +/ip4/185.185.123.124/udp/40346/quic/p2p/QmaVTB2LwayToK2rzMkaCbkCaH7nF2rTHIS0IS0AN0EXAMPLE +/ip4/127.0.0.1/udp/40346/quic/p2p/QmaVTB2LwayToK2rzMkaCbkCaH7nF2rTHIS0IS0AN0EXAMPLE +Global IP: 185.185.123.124 +``` +The lines that contain addresses that other nodes can use to connect to the network: +- `127.0.0.1` or `192.168.X.Y` are only accessible from your computer or local network, respectively. +- The remaining address is __global__ (`185.185.123.124` in the example, yours will be different). + +To connect a new peer to the network, you should specify `initial_peers` as the addresses that +correspond to the public IP: + +```python +import hivemind +dht = hivemind.DHT( + host_maddrs=["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"], + initial_peers=[ + "/ip4/185.185.123.124/tcp/40615/p2p/QmaVTB2LwayToK2rzMkaCbkCaH7nF2rTHIS0IS0AN0EXAMPLE", + "/ip4/185.185.123.124/udp/40346/quic/p2p/QmaVTB2LwayToK2rzMkaCbkCaH7nF2rTHIS0IS0AN0EXAMPLE", + ], start=True) +``` + +Thats it, now the two DHT nodes are connected. If you connect additional peers to the network, you only need to specify +one (or a subset) of peers as `initial_peers`. +In case your peer operates behind a restrictive firewall, you may find it beneficial to set `client_mode=True`. In this + case, the DHT instance will access others, but it will not announce that other peers can connect to it. + +Another (experimental) way is to use [IPFS](https://ipfs.io/): a global decentralized network for file storage. +We are not storing any files here: instead, we can use IPFS nodes to help hivemind peers find each other. +To use this strategy, set `use_ipfs=True` in each DHT node you create. This allows you to connect DHT multiple even if +all of them are behind NAT. However, this strategy may be unreliable and depend heavily on the availability of public +IPFS nodes. + +To learn more about the network address format, read [libp2p addressing](https://docs.libp2p.io/concepts/addressing/) +For an example of how to set up DHT in a distributed training experiment, see + [examples/albert](https://github.com/learning-at-home/hivemind/tree/master/examples/albert) diff --git a/docs/user/moe.md b/docs/user/moe.md new file mode 100644 index 000000000..d20ef7a61 --- /dev/null +++ b/docs/user/moe.md @@ -0,0 +1,184 @@ +# Mixture-of-Experts + +This tutorial covers the basics of Decentralized Mixture-of-Experts (DMoE). +From the infrastructure standpoint, DMoE consists of two parts: experts hosted on peer devices, and a gating/routing function that assigns input to one of these experts. + +## Host experts with a server + +`hivemind.moe.Server` hosts one or several experts (PyTorch modules) for remote access. These experts are responsible for +most of the model parameters and computation. The server can be started using either Python or +[a shell script](https://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_server.py). We'll use the shell +for now. To host a server with default experts, run this in your shell: + +```sh +hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 5 --expert_pattern "expert.[0:5]" \ + --listen_on 0.0.0.0:1337 +# note: if you omit listen_on and/or dht_port, they will be chosen automatically and printed to stdout. +``` + +
+ Console outputs + +```sh +[2021/07/15 18:52:01.424][INFO][moe.server.create:156] Running DHT node on ['/ip4/127.0.0.1/tcp/42513/p2p/QmacLgRkAHSqdWYdQ8TePioMxQCNV2JeD3AUDmbVd69gNL'], initial peers = [] +[2021/07/15 18:52:01.424][INFO][moe.server.create:181] Generating 5 expert uids from pattern expert.[0:5] +[2021/07/15 18:52:01.658][INFO][moe.server.run:233] Server started at 0.0.0.0:1337 +[2021/07/15 18:52:01.658][INFO][moe.server.run:234] Got 5 experts: +[2021/07/15 18:52:01.658][INFO][moe.server.run:237] expert.4: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:52:01.658][INFO][moe.server.run:237] expert.0: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:52:01.659][INFO][moe.server.run:237] expert.3: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:52:01.659][INFO][moe.server.run:237] expert.2: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:52:01.659][INFO][moe.server.run:237] expert.1: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:52:02.447][INFO][moe.server.task_pool.run:145] expert.4_forward starting, pid=14038 +[2021/07/15 18:52:02.468][INFO][moe.server.task_pool.run:145] expert.4_backward starting, pid=14042 +[2021/07/15 18:52:02.469][INFO][moe.server.task_pool.run:145] expert.0_forward starting, pid=14044 +[2021/07/15 18:52:02.484][INFO][moe.server.task_pool.run:145] expert.0_backward starting, pid=14051 +[2021/07/15 18:52:02.501][INFO][moe.server.task_pool.run:145] expert.3_forward starting, pid=14057 +[2021/07/15 18:52:02.508][INFO][moe.server.task_pool.run:145] expert.3_backward starting, pid=14058 +[2021/07/15 18:52:02.508][INFO][moe.server.task_pool.run:145] expert.2_forward starting, pid=14060 +[2021/07/15 18:52:02.521][INFO][moe.server.task_pool.run:145] expert.2_backward starting, pid=14070 +[2021/07/15 18:52:02.521][INFO][moe.server.task_pool.run:145] expert.1_forward starting, pid=14075 +[2021/07/15 18:52:02.532][INFO][moe.server.task_pool.run:145] expert.1_backward starting, pid=14081 +[2021/07/15 18:52:02.532][INFO][moe.server.runtime.run:80] Started +``` + +
+ + +This server serves 5 feedforward experts with ReLU and LayerNorm +(see +architecture [here](https://github.com/learning-at-home/hivemind/blob/master/hivemind/server/layers/__init__.py#L7-L21)) +. In order to connect to this server, you should copy its address from console outputs: +```shell +[...][INFO][moe.server.create:156] Running DHT node on ['ADDRESS_WILL_BE_PRINTED_HERE'] +``` + + +You can create additional servers in the same decentralized network using the `--initial_peers` argument: + +```sh +hivemind-server --expert_cls ffn --hidden_dim 512 --num_experts 10 --expert_pattern "expert.[5:250]" \ + --initial_peers /ip4/127.0.0.1/tcp/42513/p2p/COPY_FULL_ADDRESS_HERE +``` + +
+ Console outputs + +```sh +[2021/07/15 18:53:41.700][INFO][moe.server.create:156] Running DHT node on ['/ip4/127.0.0.1/tcp/34487/p2p/QmcJ3jgbdwphLAiwGjvwrjimJJrdMyhLHf6tFj9viCFFGn'], initial peers = ['/ip4/127.0.0.1/tcp/42513/p2p/QmacLgRkAHSqdWYdQ8TePioMxQCNV2JeD3AUDmbVd69gNL'] +[2021/07/15 18:53:41.700][INFO][moe.server.create:181] Generating 10 expert uids from pattern expert.[5:250] +[2021/07/15 18:53:42.085][INFO][moe.server.run:233] Server started at 0.0.0.0:36389 +[2021/07/15 18:53:42.086][INFO][moe.server.run:234] Got 10 experts: +[2021/07/15 18:53:42.086][INFO][moe.server.run:237] expert.55: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:53:42.086][INFO][moe.server.run:237] expert.173: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:53:42.086][INFO][moe.server.run:237] expert.164: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:53:42.086][INFO][moe.server.run:237] expert.99: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:53:42.086][INFO][moe.server.run:237] expert.149: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:53:42.087][INFO][moe.server.run:237] expert.66: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:53:42.087][INFO][moe.server.run:237] expert.106: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:53:42.087][INFO][moe.server.run:237] expert.31: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:53:42.087][INFO][moe.server.run:237] expert.95: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:53:42.087][INFO][moe.server.run:237] expert.167: FeedforwardBlock, 2100736 parameters +[2021/07/15 18:53:43.892][INFO][moe.server.task_pool.run:145] expert.55_forward starting, pid=14854 +[2021/07/15 18:53:43.901][INFO][moe.server.task_pool.run:145] expert.55_backward starting, pid=14858 +[2021/07/15 18:53:43.915][INFO][moe.server.task_pool.run:145] expert.173_forward starting, pid=14862 +[2021/07/15 18:53:43.929][INFO][moe.server.task_pool.run:145] expert.173_backward starting, pid=14864 +[2021/07/15 18:53:43.930][INFO][moe.server.task_pool.run:145] expert.164_forward starting, pid=14869 +[2021/07/15 18:53:43.948][INFO][moe.server.task_pool.run:145] expert.164_backward starting, pid=14874 +[2021/07/15 18:53:43.968][INFO][moe.server.task_pool.run:145] expert.99_forward starting, pid=14883 +[2021/07/15 18:53:43.977][INFO][moe.server.task_pool.run:145] expert.99_backward starting, pid=14888 +[2021/07/15 18:53:43.995][INFO][moe.server.task_pool.run:145] expert.149_forward starting, pid=14889 +[2021/07/15 18:53:44.007][INFO][moe.server.task_pool.run:145] expert.149_backward starting, pid=14898 +[2021/07/15 18:53:44.021][INFO][moe.server.task_pool.run:145] expert.66_forward starting, pid=14899 +[2021/07/15 18:53:44.034][INFO][moe.server.task_pool.run:145] expert.106_forward starting, pid=14909 +[2021/07/15 18:53:44.036][INFO][moe.server.task_pool.run:145] expert.66_backward starting, pid=14904 +[2021/07/15 18:53:44.058][INFO][moe.server.task_pool.run:145] expert.106_backward starting, pid=14919 +[2021/07/15 18:53:44.077][INFO][moe.server.task_pool.run:145] expert.31_forward starting, pid=14923 +[2021/07/15 18:53:44.077][INFO][moe.server.task_pool.run:145] expert.31_backward starting, pid=14925 +[2021/07/15 18:53:44.095][INFO][moe.server.task_pool.run:145] expert.95_forward starting, pid=14932 +[2021/07/15 18:53:44.106][INFO][moe.server.task_pool.run:145] expert.95_backward starting, pid=14935 +[2021/07/15 18:53:44.118][INFO][moe.server.task_pool.run:145] expert.167_forward starting, pid=14943 +[2021/07/15 18:53:44.119][INFO][moe.server.task_pool.run:145] expert.167_backward starting, pid=14944 +[2021/07/15 18:53:44.123][INFO][moe.server.runtime.run:80] Started +``` + +
+ +By default, the server will only accept connections from your local machine. To access it globally, you should replace +`127.0.0.1` part from initial peers with server's IP address. Hivemind supports both ipv4 and ipv6 protocols and uses the same notation +as [libp2p](https://docs.libp2p.io/concepts/addressing/). You can find more details on multiaddresses in the +[DHT tutorial](https://learning-at-home.readthedocs.io/en/latest/user/dht.html). + +## Train the experts + +Now let's put these experts to work. Create a python console (or a jupyter) and run: + +```python +import torch +import hivemind + +dht = hivemind.DHT( + initial_peers=["/ip4/127.0.0.1/tcp/TODO/COPYFULL_ADDRESS/FROM_ONE_OF_THE_SERVERS"], + client_mode=True, start=True) + +# note: client_mode=True means that your peer will operate in a "client-only" mode: +# this means that it can request other peers, but will not accept requests in return + +expert1, expert4 = hivemind.moe.get_experts(dht, ["expert.1", "expert.4"]) +assert expert1 is not None and expert4 is not None, "experts not found. Please double-check initial peers" +``` + +Each expert (e.g. `expert1`) can be used as a pytorch module with autograd support: + +```python +dummy = torch.randn(3, 512) +out = expert1(dummy) # forward pass +out.sum().backward() # backward pass +``` + +When called, `expert1` will submit a request to the corresponding server (which you created above) and return the output +tensor(s) or raise an exception. During backward, pytorch will submit the backward requests for the experts as they +appear in the computation graph. + +By default, the experts will automatically update their parameters with one step of SGD after each backward pass. This +allows you to quickly run training using both local and remote layers: + +```python +# generate dummy data +x = torch.randn(3, 512) +y = 0.01 * x.sum(dim=-1, keepdim=True) + +# local torch module +proj_out = torch.nn.Sequential( + torch.nn.Linear(512, 3) +) +opt = torch.optim.SGD(proj_out.parameters(), lr=0.01) + +for i in range(100): + prediction = proj_out(expert1(expert4(x))) + loss = torch.mean(abs(prediction - y)) + print(loss.item()) + opt.zero_grad() + loss.backward() + opt.step() +``` + +Finally, you can create a Mixture-of-Experts layer over these experts: + +```python +import nest_asyncio; nest_asyncio.apply() # asyncio patch for jupyter. for now, we recommend using MoE from console + +dmoe = hivemind.RemoteMixtureOfExperts(in_features=512, uid_prefix="expert.", grid_size=(5,), + dht=dht, k_best=2) + +out = dmoe(torch.randn(3, 512)) +out.sum().backward() +``` + +The `dmoe` layer dynamically selects the right experts using a linear gating function. It will then dispatch parallel +forward (and backward) requests to those experts and collect results. You can find more details on how DMoE works in +Section 2.3 of [(Ryabinin et al, 2020)](https://arxiv.org/abs/2002.04013). In addition to traditional MoE, hivemind +implements `hivemind.RemoteSwitchMixtureOfExperts` using the simplified routing algorithm [(Fedus et al 2021)](https://arxiv.org/abs/2101.03961). + +For more code examples related to DMoE, such as defining custom experts or using switch-based routing, please refer to +[`hivemind/tests/test_training.py`](https://github.com/learning-at-home/hivemind/blob/master/tests/test_training.py). diff --git a/docs/user/quickstart.md b/docs/user/quickstart.md index 152cb3192..a55739151 100644 --- a/docs/user/quickstart.md +++ b/docs/user/quickstart.md @@ -4,212 +4,184 @@ This tutorial will teach you how to install `hivemind`, host your own experts an ## Installation -Just `pip install hivemind` to get the latest release. +Just `pip install hivemind` to get the latest release (requires Python 3.7 or newer). You can also install the bleeding edge version from GitHub: ``` git clone https://github.com/learning-at-home/hivemind cd hivemind -pip install . +pip install -e . ``` + +## Decentralized Training -You can also install it in the editable mode with `pip install -e .`. +Hivemind is a set of building blocks for decentralized training. +In this tutorial, we'll use two of these blocks to train a simple neural network to classify CIFAR-10 images. +We assume that you are already familiar with the official [CIFAR-10 example](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) +from the PyTorch website. -* __Dependencies:__ Hivemind requires Python 3.7+. - The [requirements](https://github.com/learning-at-home/hivemind/blob/master/requirements.txt) are installed - automatically. -* __OS support:__ Linux and macOS should just work. We do not officially support Windows, but you are welcome to - contribute your windows build :) +We build on top of the official example to spin up distributed training of a two-layer neural network by averaging weights. +For simplicity, this tutorial will use two non-GPU peers running on the same machine. If you get to the end of this +tutorial, we'll give you an example of actual distributed training of Transformers ;) -## Host a server - -`hivemind.moe.Server` hosts one or several experts (PyTorch modules) for remote access. These experts are responsible for -most of the model parameters and computation. The server can be started using either Python or -[a shell script](https://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_server.py). We'll use the shell -for now. To host a server with default experts, run this in your shell: +For now, let's run our first training peer: +```python +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import datasets, transforms +from tqdm.auto import tqdm -```sh -python hivemind/hivemind_cli/run_server.py --expert_cls ffn --hidden_dim 512 --num_experts 5 --expert_pattern "expert.[0:5]" \ - --listen_on 0.0.0.0:1337 --dht_port 1338 -# note: if you omit listen_on and/or dht_port, they will be chosen automatically and printed to stdout. -``` +import hivemind -
- Console outputs - -```sh -[2020/08/26 11:54:52.645][INFO][server.create:101] Bootstrapping DHT node, initial peers = [] -[2020/08/26 11:54:52.660][INFO][server.create:105] Running dht node on port 1338 -[2020/08/26 11:54:53.182][INFO][server.task_pool.run:130] expert.0_forward starting, pid=19382 -[2020/08/26 11:54:53.182][INFO][server.task_pool.run:130] expert.0_forward starting, pid=19382 -[2020/08/26 11:54:53.189][INFO][server.task_pool.run:130] expert.0_backward starting, pid=19384 -[2020/08/26 11:54:53.189][INFO][server.task_pool.run:130] expert.0_backward starting, pid=19384 -[2020/08/26 11:54:53.196][INFO][server.task_pool.run:130] expert.1_forward starting, pid=19386 -[2020/08/26 11:54:53.196][INFO][server.task_pool.run:130] expert.1_forward starting, pid=19386 -[2020/08/26 11:54:53.206][INFO][server.task_pool.run:130] expert.1_backward starting, pid=19388 -[2020/08/26 11:54:53.206][INFO][server.task_pool.run:130] expert.1_backward starting, pid=19388 -[2020/08/26 11:54:53.212][INFO][server.task_pool.run:130] expert.2_forward starting, pid=19390 -[2020/08/26 11:54:53.212][INFO][server.task_pool.run:130] expert.2_forward starting, pid=19390 -[2020/08/26 11:54:53.218][INFO][server.task_pool.run:130] expert.2_backward starting, pid=19392 -[2020/08/26 11:54:53.218][INFO][server.task_pool.run:130] expert.2_backward starting, pid=19392 -[2020/08/26 11:54:53.225][INFO][server.task_pool.run:130] expert.3_forward starting, pid=19394 -[2020/08/26 11:54:53.225][INFO][server.task_pool.run:130] expert.3_forward starting, pid=19394 -[2020/08/26 11:54:53.232][INFO][server.task_pool.run:130] expert.3_backward starting, pid=19396 -[2020/08/26 11:54:53.232][INFO][server.task_pool.run:130] expert.3_backward starting, pid=19396 -[2020/08/26 11:54:53.235][INFO][server.task_pool.run:130] expert.4_forward starting, pid=19398 -[2020/08/26 11:54:53.235][INFO][server.task_pool.run:130] expert.4_forward starting, pid=19398 -[2020/08/26 11:54:53.241][INFO][server.task_pool.run:130] expert.4_backward starting, pid=19400 -[2020/08/26 11:54:53.241][INFO][server.task_pool.run:130] expert.4_backward starting, pid=19400 -[2020/08/26 11:54:53.244][INFO][server.runtime.run:60] Started -[2020/08/26 11:54:53.244][INFO][server.runtime.run:60] Started -[2020/08/26 11:54:53.245][INFO][server.create:136] Server started at 0.0.0.0:1337 -[2020/08/26 11:54:53.245][INFO][server.create:137] Got 5 active experts of type ffn: ['expert.0', 'expert.1', 'expert.2', 'expert.3', 'expert.4'] -``` +# Create dataset and model, same as in the basic tutorial +# For this basic tutorial, we download only the training set +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) -
+trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) +model = nn.Sequential(nn.Conv2d(3, 16, (5, 5)), nn.MaxPool2d(2, 2), nn.ReLU(), + nn.Conv2d(16, 32, (5, 5)), nn.MaxPool2d(2, 2), nn.ReLU(), + nn.Flatten(), nn.Linear(32 * 5 * 5, 10)) +opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) -This server accepts requests to experts on port 1337 and start a DHT peer on port 1338. In total, it serves 5 -feedforward experts with ReLU and LayerNorm -(see -architecture [here](https://github.com/learning-at-home/hivemind/blob/master/hivemind/server/layers/__init__.py#L7-L21)) -. -You can create additional servers in the same decentralized network using `--initial_peers` argument: +# Create DHT: a decentralized key-value storage shared between peers +dht = hivemind.DHT(start=True) +print("To join the training, use initial_peers =", [str(addr) for addr in dht.get_visible_maddrs()]) -```sh -python hivemind/hivemind_cli/run_server.py --expert_cls ffn --hidden_dim 512 --num_experts 10 --expert_pattern "expert.[5:250]" \ - --initial_peers localhost:1338 +# Set up a decentralized optimizer that will average with peers in background +opt = hivemind.optim.DecentralizedOptimizer( + opt, # wrap the SGD optimizer defined above + dht, # use a DHT that is connected with other peers + average_parameters=True, # periodically average model weights in opt.step + average_gradients=False, # do not average accumulated gradients + prefix='my_cifar_run', # unique identifier of this collaborative run + target_group_size=16, # maximum concurrent peers for this run + verbose=True # print logs incessently +) +# Note: if you intend to use GPU, switch to it only after the decentralized optimizer is created + +with tqdm() as progressbar: + while True: + for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=256): + opt.zero_grad() + loss = F.cross_entropy(model(x_batch), y_batch) + loss.backward() + opt.step() + + progressbar.desc = f"loss = {loss.item():.3f}" + progressbar.update() ``` -
- Console outputs - -```sh -[2020/08/26 13:15:05.078][INFO][server.create:103] Bootstrapping DHT node, initial peers = ['localhost:1338'] -[2020/08/26 13:15:05.101][INFO][server.create:107] Running dht node on port 44291 -expert.[5:250] -[2020/08/26 13:15:06.326][INFO][server.task_pool.run:130] expert.113_forward starting, pid=29517 -[2020/08/26 13:15:06.326][INFO][server.task_pool.run:130] expert.113_forward starting, pid=29517 -[2020/08/26 13:15:06.333][INFO][server.task_pool.run:130] expert.113_backward starting, pid=29519 -[2020/08/26 13:15:06.333][INFO][server.task_pool.run:130] expert.113_backward starting, pid=29519 -[2020/08/26 13:15:06.340][INFO][server.task_pool.run:130] expert.149_forward starting, pid=29521 -[2020/08/26 13:15:06.340][INFO][server.task_pool.run:130] expert.149_forward starting, pid=29521 -[2020/08/26 13:15:06.352][INFO][server.task_pool.run:130] expert.149_backward starting, pid=29523 -[2020/08/26 13:15:06.352][INFO][server.task_pool.run:130] expert.149_backward starting, pid=29523 -[2020/08/26 13:15:06.363][INFO][server.task_pool.run:130] expert.185_forward starting, pid=29525 -[2020/08/26 13:15:06.363][INFO][server.task_pool.run:130] expert.185_forward starting, pid=29525 -[2020/08/26 13:15:06.375][INFO][server.task_pool.run:130] expert.185_backward starting, pid=29527 -[2020/08/26 13:15:06.375][INFO][server.task_pool.run:130] expert.185_backward starting, pid=29527 -[2020/08/26 13:15:06.381][INFO][server.task_pool.run:130] expert.189_forward starting, pid=29529 -[2020/08/26 13:15:06.381][INFO][server.task_pool.run:130] expert.189_forward starting, pid=29529 -[2020/08/26 13:15:06.388][INFO][server.task_pool.run:130] expert.189_backward starting, pid=29531 -[2020/08/26 13:15:06.388][INFO][server.task_pool.run:130] expert.189_backward starting, pid=29531 -[2020/08/26 13:15:06.400][INFO][server.task_pool.run:130] expert.191_forward starting, pid=29533 -[2020/08/26 13:15:06.400][INFO][server.task_pool.run:130] expert.191_forward starting, pid=29533 -[2020/08/26 13:15:06.407][INFO][server.task_pool.run:130] expert.191_backward starting, pid=29535 -[2020/08/26 13:15:06.407][INFO][server.task_pool.run:130] expert.191_backward starting, pid=29535 -[2020/08/26 13:15:06.415][INFO][server.task_pool.run:130] expert.196_forward starting, pid=29537 -[2020/08/26 13:15:06.415][INFO][server.task_pool.run:130] expert.196_forward starting, pid=29537 -[2020/08/26 13:15:06.426][INFO][server.task_pool.run:130] expert.196_backward starting, pid=29539 -[2020/08/26 13:15:06.426][INFO][server.task_pool.run:130] expert.196_backward starting, pid=29539 -[2020/08/26 13:15:06.435][INFO][server.task_pool.run:130] expert.225_forward starting, pid=29541 -[2020/08/26 13:15:06.435][INFO][server.task_pool.run:130] expert.225_forward starting, pid=29541 -[2020/08/26 13:15:06.445][INFO][server.task_pool.run:130] expert.225_backward starting, pid=29543 -[2020/08/26 13:15:06.445][INFO][server.task_pool.run:130] expert.225_backward starting, pid=29543 -[2020/08/26 13:15:06.454][INFO][server.task_pool.run:130] expert.227_forward starting, pid=29545 -[2020/08/26 13:15:06.454][INFO][server.task_pool.run:130] expert.227_forward starting, pid=29545 -[2020/08/26 13:15:06.467][INFO][server.task_pool.run:130] expert.227_backward starting, pid=29547 -[2020/08/26 13:15:06.467][INFO][server.task_pool.run:130] expert.227_backward starting, pid=29547 -[2020/08/26 13:15:06.475][INFO][server.task_pool.run:130] expert.36_forward starting, pid=29549 -[2020/08/26 13:15:06.475][INFO][server.task_pool.run:130] expert.36_forward starting, pid=29549 -[2020/08/26 13:15:06.482][INFO][server.task_pool.run:130] expert.36_backward starting, pid=29551 -[2020/08/26 13:15:06.482][INFO][server.task_pool.run:130] expert.36_backward starting, pid=29551 -[2020/08/26 13:15:06.497][INFO][server.task_pool.run:130] expert.58_forward starting, pid=29553 -[2020/08/26 13:15:06.497][INFO][server.task_pool.run:130] expert.58_forward starting, pid=29553 -[2020/08/26 13:15:06.507][INFO][server.task_pool.run:130] expert.58_backward starting, pid=29555 -[2020/08/26 13:15:06.507][INFO][server.task_pool.run:130] expert.58_backward starting, pid=29555 -[2020/08/26 13:15:06.509][INFO][server.runtime.run:60] Started -[2020/08/26 13:15:06.509][INFO][server.runtime.run:60] Started -[2020/08/26 13:15:06.510][INFO][server.create:166] Server started at 0.0.0.0:40089 -[2020/08/26 13:15:06.510][INFO][server.create:167] Got 10 active experts of type ffn: ['expert.113', 'expert.149', 'expert.185', 'expert.189', 'expert.191', 'expert.196', 'expert.225', 'expert.227', 'expert.36', 'expert.58'] + +As you can see, this code is regular PyTorch with one notable exception: it wraps your regular optimizer with a +`DecentralizedOptimizer`. This optimizer uses `DHT` to find other peers and tries to exchange weights them. When you run +the code (please do so), you will see the following output: + +```shell +To join the training, use initial_peers = ['/ip4/127.0.0.1/tcp/XXX/p2p/YYY'] +[...] Starting a new averaging round with current parameters. ``` -
+This is `DecentralizedOptimizer` telling you that it's looking for peers. Since there are no peers, we'll need to create +them ourselves. -Here and below, if you are running on a different machine, replace `localhost:1338` with your original server's public -IP address (e.g. `12.34.56.78:1338`). Hivemind supports both ipv4 and ipv6 protocols and uses the same notation -as [gRPC](https://grpc.io/docs/languages/python/basics/#starting-the-server). +Copy the entire script (or notebook) and modify this line: -## Train the experts +```python +# old version: +dht = hivemind.DHT(start=True) -Now let's put these experts to work. Create a python console (or a jupyter) and run: +# new version: added initial_peers +dht = hivemind.DHT(initial_peers=['/ip4/127.0.0.1/tcp/COPY_FULL_ADDRESS_FROM_PEER1_OUTPUTS'], start=True) +``` +
+ Here's the full code of the second peer ```python import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import datasets, transforms +from tqdm.auto import tqdm + import hivemind -dht = hivemind.DHT(initial_peers=["localhost:1338"], listen=False, start=True) -# note: listen=False means that your peer will operate in "client only" mode: -# this means that it can request other peers, but will not accept requests in return +# Create dataset and model, same as in the basic tutorial +# For this basic tutorial, we download only the training set +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + +trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) + +model = nn.Sequential(nn.Conv2d(3, 16, (5, 5)), nn.MaxPool2d(2, 2), nn.ReLU(), + nn.Conv2d(16, 32, (5, 5)), nn.MaxPool2d(2, 2), nn.ReLU(), + nn.Flatten(), nn.Linear(32 * 5 * 5, 10)) +opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) + +# Create DHT: a decentralized key-value storage shared between peers +dht = hivemind.DHT(initial_peers=[COPY_FROM_ANOTHER_PEER_OUTPUTS], start=True) +print("To join the training, use initial_peers =", [str(addr) for addr in dht.get_visible_maddrs()]) + +# Set up a decentralized optimizer that will average with peers in background +opt = hivemind.optim.DecentralizedOptimizer( + opt, # wrap the SGD optimizer defined above + dht, # use a DHT that is connected with other peers + average_parameters=True, # periodically average model weights in opt.step + average_gradients=False, # do not average accumulated gradients + prefix='my_cifar_run', # unique identifier of this collaborative run + target_group_size=16, # maximum concurrent peers for this run + verbose=True # print logs incessently +) -expert1, expert4 = hivemind.get_experts(dht, ["expert.1", "expert.4"]) -assert expert1 is not None and expert4 is not None, "server hasn't declared experts (yet?)" -``` +opt.averager.load_state_from_peers() -The experts (e.g. `expert1`) can be used as a pytorch module with autograd support: +# Note: if you intend to use GPU, switch to it only after the decentralized optimizer is created +with tqdm() as progressbar: + while True: + for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=256): + opt.zero_grad() + loss = F.cross_entropy(model(x_batch), y_batch) + loss.backward() + opt.step() -```python -dummy = torch.randn(3, 512) -out = expert1(dummy) # forward pass -out.sum().backward() # backward pass + progressbar.desc = f"loss = {loss.item():.3f}" + progressbar.update() ``` +
-When called, expert1 will submit a request to the corresponding server (which you created above) and return the output -tensor(s) or raise an exception. During backward, pytorch will submit the backward requests for the experts as they -appear in the computation graph. - -By default, the experts will automatically update their parameters with one step of SGD after each backward pass. This -allows you to quickly run training using both local and remote layers: -```python -# generate dummy data -x = torch.randn(3, 512) -y = 0.01 * x.sum(dim=-1, keepdim=True) +Instead of setting up a new DHT, the second peer will link up with the existing DHT node from the first peer. +If you run the second peer, you will see that both first and second peer will periodically report averaging parameters: -# local torch module -proj_out = torch.nn.Sequential( - torch.nn.Linear(512, 3) -) -opt = torch.optim.SGD(proj_out.parameters(), lr=0.01) - -for i in range(100): - prediction = proj_out(expert1(expert4(x))) - loss = torch.mean(abs(prediction - y)) - print(loss.item()) - opt.zero_grad() - loss.backward() - opt.step() +```shell +[...] Starting a new averaging round with current parameters. +[...] Finished averaging round in with 2 peers. ``` -Finally, you can create a Mixture-of-Experts layer over these experts: +This message means that the optimizer has averaged model parameters with another peer in background and applied them +during one of the calls to `opt.step()`. You can start more peers by replicating the same code as the second peer, +using either the first or second peer as `initial_peers`. +The only issue with this code is that each new peer starts with a different untrained network blends its un-trained +parameters with other peers, reseting their progress. You can see this effect as a spike increase in training loss +immediately after new peer joins training. To avoid this problem, the second peer can download the +current model/optimizer state from an existing peer right before it begins training on minibatches: ```python -import nest_asyncio - -nest_asyncio.apply() # asyncio patch for jupyter. for now, we recommend using MoE from console -dmoe = hivemind.RemoteMixtureOfExperts(in_features=512, uid_prefix="expert", grid_size=(5,), - dht=dht, k_best=2) - -out = dmoe(torch.randn(3, 512)) -out.sum().backward() +opt.averager.load_state_from_peers() ``` -The `dmoe` layer dynamically selects the right experts using a linear gating function. It will then dispatch parallel -forward (and backward) requests to those experts and collect results. You can find more details on how DMoE works in -Section 2.3 of the [paper](https://arxiv.org/abs/2002.04013) +Congrats, you've just started a pocket-sized experiment with decentralized deep learning! -Congratulations, you've made it through the basic tutorial. Give yourself a pat on the back :) +However, this is just the bare minimum of what hivemind can do. In [this example](https://github.com/learning-at-home/hivemind/tree/master/examples/albert), +we show how to use a more advanced version of DecentralizedOptimizer to collaboratively train a large Transformer over the internet. -More advanced tutorials are coming soon :) +If you want to learn more about each individual component, +- Learn how to use `hivemind.DHT` using this basic [DHT tutorial](https://learning-at-home.readthedocs.io/en/latest/user/dht.html), +- Learn the underlying math behind DecentralizedOptimizer in + [(Li et al. 2020)](https://arxiv.org/abs/2005.00124) and [(Ryabinin et al. 2021)](https://arxiv.org/abs/2103.03239). +- Read about setting up Mixture-of-Experts training in [this guide](https://learning-at-home.readthedocs.io/en/latest/user/moe.html), + diff --git a/hivemind/averaging/training.py b/hivemind/averaging/training.py index 17fc4c0e1..b75ee0c7d 100644 --- a/hivemind/averaging/training.py +++ b/hivemind/averaging/training.py @@ -2,7 +2,7 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import nullcontext from itertools import chain -from threading import Lock +from threading import Lock, Event from typing import Sequence, Dict, Iterator, Optional import torch @@ -50,6 +50,8 @@ def __init__( self.average_parameters, self.average_gradients = average_parameters, average_gradients self.step_executor = ThreadPoolExecutor(max_workers=1) self.lock_averager_step = Lock() + self.pending_updates_done = Event() + self.pending_updates_done.set() if initialize_optimizer: initialize_optimizer_state(opt) # note: this will run one optimizer step! @@ -75,6 +77,7 @@ def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs): local_tensors = list(self.local_tensors()) with self.lock_averager_step, torch.no_grad(): # fill averager's tensors with current local tensors + self.pending_updates_done.clear() with data_lock, self.get_tensors() as averaged_tensors: if use_old_local_tensors: old_local_tensors = tuple(x.cpu().float().clone() for x in local_tensors) @@ -83,11 +86,13 @@ def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs): ), "The number of optimized parameters should not change." for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors): averaged_tensor[...] = local_tensor.cpu().float() + self.pending_updates_done.set() # find a group and hopefully average tensors with peers, use batch sizes as weights gathered = super().step(**kwargs) if gathered is not None: # load averaged tensors back into model + self.pending_updates_done.clear() with data_lock, self.get_tensors() as averaged_tensors: if len(averaged_tensors) != len(local_tensors): raise RuntimeError("The number of optimized parameters should not change.") @@ -109,6 +114,7 @@ def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs): else: for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors): local_tensor.copy_(averaged_tensor, non_blocking=True) + self.pending_updates_done.set() self.local_step += 1 return gathered diff --git a/hivemind/moe/__init__.py b/hivemind/moe/__init__.py index 836d400dc..a488b2d14 100644 --- a/hivemind/moe/__init__.py +++ b/hivemind/moe/__init__.py @@ -1,2 +1,2 @@ from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts -from hivemind.moe.server import ExpertBackend, Server, register_expert_class +from hivemind.moe.server import ExpertBackend, Server, register_expert_class, get_experts, declare_experts diff --git a/hivemind/moe/server/__init__.py b/hivemind/moe/server/__init__.py index f318fd783..575fd1971 100644 --- a/hivemind/moe/server/__init__.py +++ b/hivemind/moe/server/__init__.py @@ -152,7 +152,8 @@ def create( dht = None else: dht = hivemind.DHT(initial_peers=initial_peers, start=True) - logger.info(f"Running DHT node on {dht.get_visible_maddrs()}, initial peers = {initial_peers}") + visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()] + logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}") assert (expert_pattern is None and num_experts is None and expert_uids is not None) or ( num_experts is not None and expert_uids is None diff --git a/hivemind/optim/simple.py b/hivemind/optim/simple.py index 276dcd35e..688df44cd 100644 --- a/hivemind/optim/simple.py +++ b/hivemind/optim/simple.py @@ -66,6 +66,7 @@ def __init__( **kwargs, ) self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event() + self.lock_parameters.acquire() # this lock is only released when averager can modify tensors in background self.background_averaging_thread = Thread( name=f"{self.__class__.__name__}", @@ -77,13 +78,17 @@ def __init__( self.background_averaging_thread.start() def step(self, *args, **kwargs): - with self.lock_parameters: - loss = self.opt.step(*args, **kwargs) - - self.local_step += 1 - if self.local_step % self.averaging_step_period == 0: - self.update_event.set() - return loss + loss = self.opt.step(*args, **kwargs) + if self.lock_parameters.locked(): + self.lock_parameters.release() + try: + self.local_step += 1 + if self.local_step % self.averaging_step_period == 0: + self.update_event.set() + self.averager.pending_updates_done.wait() + return loss + finally: + self.lock_parameters.acquire() def zero_grad(self, *args, **kwargs): return self.opt.zero_grad(*args, **kwargs) diff --git a/setup.py b/setup.py index a097243b5..a9f7b0688 100644 --- a/setup.py +++ b/setup.py @@ -140,10 +140,10 @@ def run(self): version=version_string, cmdclass={"build_py": BuildPy, "develop": Develop}, description="Decentralized deep learning in PyTorch", - long_description="Decentralized deep learning in PyTorch. Built to train giant models on " - "thousands of volunteers across the world.", + long_description="Decentralized deep learning in PyTorch. Built to train models on thousands of volunteers " + "across the world.", author="Learning@home & contributors", - author_email="mryabinin0@gmail.com", + author_email="hivemind-team@hotmail.com", url="https://github.com/learning-at-home/hivemind", packages=find_packages(exclude=["tests"]), package_data={"hivemind": ["proto/*", "hivemind_cli/*"]}, diff --git a/tests/test_training.py b/tests/test_training.py index e3d1a330c..c82d383df 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -157,10 +157,13 @@ def test_decentralized_optimizer_step(): (param1.sum() + 300 * param2.sum()).backward() - opt1.step() - opt2.step() + for i in range(5): + time.sleep(0.1) + opt1.step() + opt2.step() + opt1.zero_grad() + opt2.zero_grad() - time.sleep(0.5) assert torch.allclose(param1, param2) reference = 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300) assert torch.allclose(param1, torch.full_like(param1, reference)) @@ -193,13 +196,15 @@ def test_decentralized_optimizer_averaging(): verbose=True, ) - assert not torch.allclose(param1, param2) - + assert not torch.allclose(param1, param2, atol=1e-3, rtol=0) (param1.sum() + param2.sum()).backward() - opt1.step() - opt2.step() + for _ in range(100): + time.sleep(0.01) + opt1.step() + opt2.step() + opt1.zero_grad() + opt2.zero_grad() - time.sleep(0.5) - assert torch.allclose(param1, param2) - assert torch.allclose(opt1.state[param1]["exp_avg_sq"], opt2.state[param2]["exp_avg_sq"]) + assert torch.allclose(param1, param2, atol=1e-3, rtol=0) + assert torch.allclose(opt1.state[param1]["exp_avg_sq"], opt2.state[param2]["exp_avg_sq"], atol=1e-3, rtol=0)