diff --git a/.circleci/config.yml b/.circleci/config.yml index c09aafe30..faa46c132 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -9,11 +9,11 @@ jobs: steps: - checkout - python/load-cache - - run: sudo pip install codecov pytest tqdm scikit-learn + - run: pip install codecov pytest tqdm scikit-learn - python/install-deps - python/save-cache - run: - command: sudo python setup.py develop + command: pip install -e . name: setup - run: command: pytest ./tests diff --git a/.readthedocs.yml b/.readthedocs.yml index 4471c4f84..7ba0b1a29 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -3,5 +3,10 @@ version: 2 sphinx: fail_on_warning: true -conda: - environment: docs/environment.yaml +python: + version: 3.7 + install: + - requirements: requirements.txt + - requirements: docs/requirements.txt + - method: pip + path: . diff --git a/docs/conf.py b/docs/conf.py index 5626f777a..bc69fe645 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,7 +22,6 @@ # -- Project information ----------------------------------------------------- -sys.path.insert(0, '..') src_path = '../hivemind' project = 'hivemind' copyright = '2020, Learning@home & contributors' diff --git a/docs/environment.yaml b/docs/environment.yaml deleted file mode 100644 index b4d5c9a16..000000000 --- a/docs/environment.yaml +++ /dev/null @@ -1,19 +0,0 @@ -channels: - - defaults - - anaconda - - pytorch - - conda-forge -dependencies: - - grpcio - - grpcio-tools - - numpy>=1.14 - - pytorch>=1.3.0 - - joblib>=0.13 - - pip - - pip: - - recommonmark - - sphinx_rtd_theme - - prefetch_generator>=1.0.1 - - uvloop>=0.14.0 - - umsgpack - diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 000000000..2051f2a0c --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,2 @@ +recommonmark +sphinx_rtd_theme \ No newline at end of file diff --git a/hivemind/__init__.py b/hivemind/__init__.py index fd5aff145..56f0ee1f8 100644 --- a/hivemind/__init__.py +++ b/hivemind/__init__.py @@ -3,4 +3,4 @@ from hivemind.server import * from hivemind.utils import * -__version__ = '0.7.1' +__version__ = '0.8.0' diff --git a/hivemind/client/expert.py b/hivemind/client/expert.py index 9f5ececbe..45827b94d 100644 --- a/hivemind/client/expert.py +++ b/hivemind/client/expert.py @@ -8,8 +8,9 @@ import torch.nn as nn from torch.autograd.function import once_differentiable +from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint -from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor, runtime_pb2, runtime_grpc +from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor DUMMY = torch.empty(0, requires_grad=True) # dummy tensor that triggers autograd in RemoteExpert diff --git a/hivemind/client/moe.py b/hivemind/client/moe.py index f5663833f..47ec2ed5d 100644 --- a/hivemind/client/moe.py +++ b/hivemind/client/moe.py @@ -1,17 +1,18 @@ from __future__ import annotations -import time + import asyncio +import time from typing import Tuple, List, Optional, Awaitable, Set, Dict +import grpc.experimental.aio import torch import torch.nn as nn from torch.autograd.function import once_differentiable -import grpc.experimental.aio import hivemind from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub -from hivemind.utils import nested_map, nested_pack, nested_flatten, runtime_grpc, runtime_pb2, \ - serialize_torch_tensor, deserialize_torch_tensor +from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc +from hivemind.utils import nested_pack, nested_flatten, serialize_torch_tensor, deserialize_torch_tensor from hivemind.utils.logging import get_logger logger = get_logger(__name__) diff --git a/hivemind/dht/protocol.py b/hivemind/dht/protocol.py index 260560d71..1bd956fe1 100644 --- a/hivemind/dht/protocol.py +++ b/hivemind/dht/protocol.py @@ -3,7 +3,6 @@ import asyncio import heapq -import os from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union, Collection from warnings import warn @@ -11,13 +10,11 @@ import grpc.experimental.aio from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time -from hivemind.utils import Endpoint, compile_grpc, get_logger, replace_port, get_port +from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc +from hivemind.utils import Endpoint, get_logger, replace_port logger = get_logger(__name__) -with open(os.path.join(os.path.dirname(__file__), 'dht.proto'), 'r') as f_proto: - dht_pb2, dht_grpc = compile_grpc(f_proto.read()) - class DHTProtocol(dht_grpc.DHTServicer): # fmt:off diff --git a/hivemind/dht/dht.proto b/hivemind/proto/dht.proto similarity index 100% rename from hivemind/dht/dht.proto rename to hivemind/proto/dht.proto diff --git a/hivemind/server/connection_handler.proto b/hivemind/proto/runtime.proto similarity index 100% rename from hivemind/server/connection_handler.proto rename to hivemind/proto/runtime.proto diff --git a/hivemind/server/connection_handler.py b/hivemind/server/connection_handler.py index b1991a25b..b55adcdb2 100644 --- a/hivemind/server/connection_handler.py +++ b/hivemind/server/connection_handler.py @@ -8,8 +8,9 @@ import torch import uvloop +from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc from hivemind.server.expert_backend import ExpertBackend -from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint, runtime_pb2, runtime_grpc +from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint logger = get_logger(__name__) diff --git a/hivemind/utils/grpc.py b/hivemind/utils/grpc.py index 3ceccc010..73015d0a3 100644 --- a/hivemind/utils/grpc.py +++ b/hivemind/utils/grpc.py @@ -1,55 +1,11 @@ """ Utilities for running GRPC services: compile protobuf, patch legacy versions, etc """ -import functools -import os -import sys -import tempfile -from argparse import Namespace -from typing import Tuple -import grpc_tools.protoc import numpy as np import torch - -@functools.lru_cache(maxsize=None) -def compile_grpc(proto: str, *args: str) -> Tuple[Namespace, Namespace]: - """ - Compiles and loads grpc protocol defined by protobuf string - - :param proto: protocol buffer code as a string, as in open('file.proto').read() - :param args: extra cli args for grpc_tools.protoc compiler, e.g. '-Imyincludepath' - :returns: messages, services protobuf - """ - base_include = grpc_tools.protoc.pkg_resources.resource_filename('grpc_tools', '_proto') - - with tempfile.TemporaryDirectory(prefix='compile_grpc_') as build_dir: - proto_path = tempfile.mktemp(prefix='grpc_', suffix='.proto', dir=build_dir) - with open(proto_path, 'w') as fproto: - fproto.write(proto) - - cli_args = ( - grpc_tools.protoc.__file__, f"-I{base_include}", - f"--python_out={build_dir}", f"--grpc_python_out={build_dir}", - f"-I{build_dir}", *args, os.path.basename(proto_path)) - code = grpc_tools.protoc._protoc_compiler.run_main([arg.encode() for arg in cli_args]) - if code: # hint: if you get this error in jupyter, run in console for richer error message - raise ValueError(f"{' '.join(cli_args)} finished with exit code {code}") - - try: - sys.path.append(build_dir) - pb2_fname = os.path.basename(proto_path)[:-len('.proto')] + '_pb2' - messages, services = __import__(pb2_fname, fromlist=['*']), __import__(pb2_fname + '_grpc') - return messages, services - finally: - if sys.path.pop() != build_dir: - raise ImportError("Something changed sys.path while compile_grpc was in progress.") - - -with open(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), - 'server', 'connection_handler.proto')) as f_proto: - runtime_pb2, runtime_grpc = compile_grpc(f_proto.read()) +from hivemind.proto import runtime_pb2 def serialize_torch_tensor(tensor: torch.Tensor) -> runtime_pb2.Tensor: diff --git a/setup.py b/setup.py index a5b0dd9b3..1d8d2eaa2 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,43 @@ -from pkg_resources import parse_requirements -from setuptools import setup import codecs -import re +import glob import os +import re + +import grpc_tools.protoc +from pkg_resources import parse_requirements +from setuptools import setup, find_packages +from setuptools.command.develop import develop +from setuptools.command.install import install + + +def proto_compile(output_path): + cli_args = ['grpc_tools.protoc', + '--proto_path=hivemind/proto', f'--python_out={output_path}', + f'--grpc_python_out={output_path}'] + glob.glob('hivemind/proto/*.proto') + + code = grpc_tools.protoc.main(cli_args) + if code: # hint: if you get this error in jupyter, run in console for richer error message + raise ValueError(f"{' '.join(cli_args)} finished with exit code {code}") + # Make pb2 imports in generated scripts relative + for script in glob.iglob(f'{output_path}/*.py'): + with open(script, 'r+') as file: + code = file.read() + file.seek(0) + file.write(re.sub(r'\n(import .+_pb2.*)', 'from . \\1', code)) + file.truncate() + + +class ProtoCompileInstall(install): + def run(self): + proto_compile(os.path.join(self.build_lib, 'hivemind', 'proto')) + super().run() + + +class ProtoCompileDevelop(develop): + def run(self): + proto_compile(os.path.join('hivemind', 'proto')) + super().run() + here = os.path.abspath(os.path.dirname(__file__)) @@ -17,12 +52,15 @@ setup( name='hivemind', version=version_string, + cmdclass={'install': ProtoCompileInstall, 'develop': ProtoCompileDevelop}, description='', long_description='', author='Learning@home authors', author_email='mryabinin@hse.ru', url="https://github.com/learning-at-home/hivemind", - packages=['hivemind'], + packages=find_packages(exclude=['tests']), + package_data={'hivemind': ['proto/*']}, + include_package_data=True, license='MIT', install_requires=install_requires, classifiers=[