Skip to content

Commit

Permalink
Compile protobuf in setup.py (#85)
Browse files Browse the repository at this point in the history
* Compile protobuf in setup.py

* Update circleci pipelines

* Update RTD pipeline

* Refactor custom build_ext into install and develop
  • Loading branch information
mryab authored Aug 23, 2020
1 parent 53f5ab6 commit e7840e3
Show file tree
Hide file tree
Showing 14 changed files with 66 additions and 85 deletions.
4 changes: 2 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ jobs:
steps:
- checkout
- python/load-cache
- run: sudo pip install codecov pytest tqdm scikit-learn
- run: pip install codecov pytest tqdm scikit-learn
- python/install-deps
- python/save-cache
- run:
command: sudo python setup.py develop
command: pip install -e .
name: setup
- run:
command: pytest ./tests
Expand Down
9 changes: 7 additions & 2 deletions .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,10 @@ version: 2
sphinx:
fail_on_warning: true

conda:
environment: docs/environment.yaml
python:
version: 3.7
install:
- requirements: requirements.txt
- requirements: docs/requirements.txt
- method: pip
path: .
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@


# -- Project information -----------------------------------------------------
sys.path.insert(0, '..')
src_path = '../hivemind'
project = 'hivemind'
copyright = '2020, Learning@home & contributors'
Expand Down
19 changes: 0 additions & 19 deletions docs/environment.yaml

This file was deleted.

2 changes: 2 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
recommonmark
sphinx_rtd_theme
2 changes: 1 addition & 1 deletion hivemind/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from hivemind.server import *
from hivemind.utils import *

__version__ = '0.7.1'
__version__ = '0.8.0'
3 changes: 2 additions & 1 deletion hivemind/client/expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import torch.nn as nn
from torch.autograd.function import once_differentiable

from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor, runtime_pb2, runtime_grpc
from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor

DUMMY = torch.empty(0, requires_grad=True) # dummy tensor that triggers autograd in RemoteExpert

Expand Down
9 changes: 5 additions & 4 deletions hivemind/client/moe.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations
import time

import asyncio
import time
from typing import Tuple, List, Optional, Awaitable, Set, Dict

import grpc.experimental.aio
import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable
import grpc.experimental.aio

import hivemind
from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
from hivemind.utils import nested_map, nested_pack, nested_flatten, runtime_grpc, runtime_pb2, \
serialize_torch_tensor, deserialize_torch_tensor
from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
from hivemind.utils import nested_pack, nested_flatten, serialize_torch_tensor, deserialize_torch_tensor
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down
7 changes: 2 additions & 5 deletions hivemind/dht/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,18 @@

import asyncio
import heapq
import os
from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union, Collection
from warnings import warn

import grpc
import grpc.experimental.aio

from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
from hivemind.utils import Endpoint, compile_grpc, get_logger, replace_port, get_port
from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
from hivemind.utils import Endpoint, get_logger, replace_port

logger = get_logger(__name__)

with open(os.path.join(os.path.dirname(__file__), 'dht.proto'), 'r') as f_proto:
dht_pb2, dht_grpc = compile_grpc(f_proto.read())


class DHTProtocol(dht_grpc.DHTServicer):
# fmt:off
Expand Down
File renamed without changes.
File renamed without changes.
3 changes: 2 additions & 1 deletion hivemind/server/connection_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import torch
import uvloop

from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
from hivemind.server.expert_backend import ExpertBackend
from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint, runtime_pb2, runtime_grpc
from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint

logger = get_logger(__name__)

Expand Down
46 changes: 1 addition & 45 deletions hivemind/utils/grpc.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,11 @@
"""
Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
"""
import functools
import os
import sys
import tempfile
from argparse import Namespace
from typing import Tuple

import grpc_tools.protoc
import numpy as np
import torch


@functools.lru_cache(maxsize=None)
def compile_grpc(proto: str, *args: str) -> Tuple[Namespace, Namespace]:
"""
Compiles and loads grpc protocol defined by protobuf string
:param proto: protocol buffer code as a string, as in open('file.proto').read()
:param args: extra cli args for grpc_tools.protoc compiler, e.g. '-Imyincludepath'
:returns: messages, services protobuf
"""
base_include = grpc_tools.protoc.pkg_resources.resource_filename('grpc_tools', '_proto')

with tempfile.TemporaryDirectory(prefix='compile_grpc_') as build_dir:
proto_path = tempfile.mktemp(prefix='grpc_', suffix='.proto', dir=build_dir)
with open(proto_path, 'w') as fproto:
fproto.write(proto)

cli_args = (
grpc_tools.protoc.__file__, f"-I{base_include}",
f"--python_out={build_dir}", f"--grpc_python_out={build_dir}",
f"-I{build_dir}", *args, os.path.basename(proto_path))
code = grpc_tools.protoc._protoc_compiler.run_main([arg.encode() for arg in cli_args])
if code: # hint: if you get this error in jupyter, run in console for richer error message
raise ValueError(f"{' '.join(cli_args)} finished with exit code {code}")

try:
sys.path.append(build_dir)
pb2_fname = os.path.basename(proto_path)[:-len('.proto')] + '_pb2'
messages, services = __import__(pb2_fname, fromlist=['*']), __import__(pb2_fname + '_grpc')
return messages, services
finally:
if sys.path.pop() != build_dir:
raise ImportError("Something changed sys.path while compile_grpc was in progress.")


with open(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
'server', 'connection_handler.proto')) as f_proto:
runtime_pb2, runtime_grpc = compile_grpc(f_proto.read())
from hivemind.proto import runtime_pb2


def serialize_torch_tensor(tensor: torch.Tensor) -> runtime_pb2.Tensor:
Expand Down
46 changes: 42 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,43 @@
from pkg_resources import parse_requirements
from setuptools import setup
import codecs
import re
import glob
import os
import re

import grpc_tools.protoc
from pkg_resources import parse_requirements
from setuptools import setup, find_packages
from setuptools.command.develop import develop
from setuptools.command.install import install


def proto_compile(output_path):
cli_args = ['grpc_tools.protoc',
'--proto_path=hivemind/proto', f'--python_out={output_path}',
f'--grpc_python_out={output_path}'] + glob.glob('hivemind/proto/*.proto')

code = grpc_tools.protoc.main(cli_args)
if code: # hint: if you get this error in jupyter, run in console for richer error message
raise ValueError(f"{' '.join(cli_args)} finished with exit code {code}")
# Make pb2 imports in generated scripts relative
for script in glob.iglob(f'{output_path}/*.py'):
with open(script, 'r+') as file:
code = file.read()
file.seek(0)
file.write(re.sub(r'\n(import .+_pb2.*)', 'from . \\1', code))
file.truncate()


class ProtoCompileInstall(install):
def run(self):
proto_compile(os.path.join(self.build_lib, 'hivemind', 'proto'))
super().run()


class ProtoCompileDevelop(develop):
def run(self):
proto_compile(os.path.join('hivemind', 'proto'))
super().run()


here = os.path.abspath(os.path.dirname(__file__))

Expand All @@ -17,12 +52,15 @@
setup(
name='hivemind',
version=version_string,
cmdclass={'install': ProtoCompileInstall, 'develop': ProtoCompileDevelop},
description='',
long_description='',
author='Learning@home authors',
author_email='[email protected]',
url="https://github.com/learning-at-home/hivemind",
packages=['hivemind'],
packages=find_packages(exclude=['tests']),
package_data={'hivemind': ['proto/*']},
include_package_data=True,
license='MIT',
install_requires=install_requires,
classifiers=[
Expand Down

0 comments on commit e7840e3

Please sign in to comment.