Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Merge pull request #318 from gnes-ai/fix-flow-1
Browse files Browse the repository at this point in the history
fix(service): make service handler thread-safe
  • Loading branch information
mergify[bot] authored Oct 11, 2019
2 parents 9095bfa + 552fcdf commit f7e7791
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 42 deletions.
4 changes: 2 additions & 2 deletions gnes/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def train(self, *args, **kwargs):
def dump(self, filename: str = None) -> None:
"""
Serialize the object to a binary file
:param filename: file path of the serialized file, if not given then `self.dump_full_path` is used
:param filename: file path of the serialized file, if not given then :py:attr:`dump_full_path` is used
"""
f = filename or self.dump_full_path
if not f:
Expand All @@ -260,7 +260,7 @@ def dump(self, filename: str = None) -> None:
def dump_yaml(self, filename: str = None) -> None:
"""
Serialize the object to a yaml file
:param filename: file path of the yaml file, if not given then `self.dump_yaml_path` is used
:param filename: file path of the yaml file, if not given then :py:attr:`dump_yaml_path` is used
"""
f = filename or self.yaml_full_path
if not f:
Expand Down
3 changes: 3 additions & 0 deletions gnes/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ def set_indexer_parser(parser=None):
if not parser:
parser = set_base_parser()
_set_sortable_service_parser(parser)
parser.add_argument('--as_response', type=ActionNoYes, default=True,
help='convert the message type from request to response after indexing. '
'turn it off if you want to chain other services after this index service.')

return parser

Expand Down
10 changes: 7 additions & 3 deletions gnes/flow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,14 @@ class Flow:
"""
GNES Flow: an intuitive way to build workflow for GNES.
You can use `.add()` then `.build()` to customize your own workflow.
You can use :py:meth:`.add()` then :py:meth:`.build()` to customize your own workflow.
For example:
.. highlight:: python
.. code-block:: python
from gnes.flow import Flow, Service as gfs
f = (Flow(check_version=False, route_table=True)
.add(gfs.Preprocessor, yaml_path='BasePreprocessor')
.add(gfs.Encoder, yaml_path='BaseEncoder')
Expand All @@ -76,9 +78,11 @@ class Flow:
flow.index()
...
You can also use the shortcuts, e.g. :py:meth:add_encoder , :py:meth:add_preprocessor
It is recommend to use flow in the context manner as showed above.
Note the different default copy behaviors in `.add()` and `.build()`:
`.add()` always copy the flow by default, whereas `.build()` modify the flow in place.
Note the different default copy behaviors in :py:meth:`.add()` and :py:meth:`.build()`:
:py:meth:`.add()` always copy the flow by default, whereas :py:meth:`.build()` modify the flow in place.
You can change this behavior by giving an argument `copy_flow=False`.
"""
Expand Down
5 changes: 3 additions & 2 deletions gnes/indexer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from functools import wraps
from typing import List, Any, Union, Callable, Tuple
from collections import defaultdict

import numpy as np

Expand All @@ -30,7 +30,8 @@ def __init__(self,
is_big_score_similar: bool = False,
*args, **kwargs):
"""
Base indexer, a valid indexer must implement `add` and `query` methods
Base indexer, a valid indexer must implement :py:meth:`add` and :py:meth:`query` methods
:type score_fn: advanced score function
:type normalize_fn: normalizing score function
:type is_big_score_similar: when set to true, then larger score means more similar
Expand Down
15 changes: 8 additions & 7 deletions gnes/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ def register(self, msg_type: Union[List, Tuple, type]):
def decorator(f):
if isinstance(msg_type, list) or isinstance(msg_type, tuple):
for m in msg_type:
self.routes[m] = f
self.routes[m] = f.__name__
else:
self.routes[msg_type] = f
self.routes[msg_type] = f.__name__
return f

return decorator
Expand All @@ -187,11 +187,12 @@ def register_hook(self, hook_type: Union[str, Tuple[str]], only_when_verbose: bo

def decorator(f):
if isinstance(hook_type, str) and hook_type in self.hooks:
self.hooks[hook_type].append((f, only_when_verbose))
self.hooks[hook_type].append((f.__name__, only_when_verbose))
return f
elif isinstance(hook_type, list) or isinstance(hook_type, tuple):
for h in set(hook_type):
if h in self.hooks:
self.hooks[h].append((f, only_when_verbose))
self.hooks[h].append((f.__name__, only_when_verbose))
else:
raise AttributeError('hook type: %s is not supported' % h)
return f
Expand Down Expand Up @@ -222,7 +223,7 @@ def call_hooks(self, msg: 'gnes_pb2.Message', hook_type: Union[str, Tuple[str]],
for fn, only_verbose in hooks:
if (only_verbose and self.service_context.args.verbose) or (not only_verbose):
try:
fn(self.service_context, msg, *args, **kwargs)
fn(msg, *args, **kwargs)
except Exception as ex:
self.logger.warning('hook %s throws an exception, '
'this wont affect the server but you may want to pay attention' % fn)
Expand All @@ -249,7 +250,7 @@ def get_default_fn(m_type):
fn = get_default_fn(type(msg))

self.logger.info('handling message with %s' % fn.__name__)
return fn(self.service_context, msg)
return fn(msg)

def call_routes_send_back(self, msg: 'gnes_pb2.Message', out_sock):
try:
Expand Down Expand Up @@ -334,7 +335,7 @@ def __init__(self, args):
check_version=self.args.check_version,
timeout=self.args.timeout,
squeeze_pb=self.args.squeeze_pb)
# self._override_handler()
self._override_handler()

def _override_handler(self):
# replace the function name by the function itself
Expand Down
32 changes: 21 additions & 11 deletions gnes/service/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,23 @@ def post_init(self):
def _handler_index(self, msg: 'gnes_pb2.Message'):
# print('tid: %s, model: %r, self._tmp_a: %r' % (threading.get_ident(), self._model, self._tmp_a))
# if self._tmp_a != threading.get_ident():
# print('tid: %s, tmp_a: %r !!! %r' % (threading.get_ident(), self._tmp_a, self._handler_index))
# print('!!! tid: %s, tmp_a: %r %r' % (threading.get_ident(), self._tmp_a, self._handler_index))
from ..indexer.base import BaseChunkIndexer, BaseDocIndexer
if isinstance(self._model, BaseChunkIndexer):
self._handler_chunk_index(msg)
is_changed = self._handler_chunk_index(msg)
elif isinstance(self._model, BaseDocIndexer):
self._handler_doc_index(msg)
is_changed = self._handler_doc_index(msg)
else:
raise ServiceError(
'unsupported indexer, dont know how to use %s to handle this message' % self._model.__bases__)
msg.response.index.status = gnes_pb2.Response.SUCCESS
self.is_model_changed.set()

def _handler_chunk_index(self, msg: 'gnes_pb2.Message'):
if self.args.as_response:
msg.response.index.status = gnes_pb2.Response.SUCCESS

if is_changed:
self.is_model_changed.set()

def _handler_chunk_index(self, msg: 'gnes_pb2.Message') -> bool:
embed_info = []

for d in msg.request.index.docs:
Expand All @@ -59,13 +63,19 @@ def _handler_chunk_index(self, msg: 'gnes_pb2.Message'):
if embed_info:
vecs, doc_ids, offsets, weights = zip(*embed_info)
self._model.add(list(zip(doc_ids, offsets)), np.stack(vecs), weights)
return True
else:
self.logger.warning('chunks contain no embedded vectors, the indexer will do nothing')

def _handler_doc_index(self, msg: 'gnes_pb2.Message'):
self._model.add([d.doc_id for d in msg.request.index.docs],
[d for d in msg.request.index.docs],
[d.weight for d in msg.request.index.docs])
return False

def _handler_doc_index(self, msg: 'gnes_pb2.Message') -> bool:
if msg.request.index.docs:
self._model.add([d.doc_id for d in msg.request.index.docs],
[d for d in msg.request.index.docs],
[d.weight for d in msg.request.index.docs])
return True
else:
return False

def _put_result_into_message(self, results, msg: 'gnes_pb2.Message'):
msg.response.search.ClearField('topk_results')
Expand Down
21 changes: 12 additions & 9 deletions tests/test_gnes_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_flow5(self):
print(f._service_edges)
print(f.to_mermaid())

def _test_index_flow(self):
def _test_index_flow(self, backend):
for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin]:
self.assertFalse(os.path.exists(k))

Expand All @@ -127,25 +127,28 @@ def _test_index_flow(self):
.add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter',
num_part=2, service_in=['vec_idx', 'doc_idx']))

with flow.build(backend='process') as f:
with flow.build(backend=backend) as f:
f.index(txt_file=self.test_file, batch_size=20)

for k in [self.indexer1_bin, self.indexer2_bin]:
self.assertTrue(os.path.exists(k))

def _test_query_flow(self):
def _test_query_flow(self, backend):
flow = (Flow(check_version=False, route_table=False)
.add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor')
.add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'))
.add(gfs.Indexer, name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml'))
.add(gfs.Router, name='scorer', yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml'))
.add(gfs.Indexer, name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml')))

with flow.build(backend='process') as f, open(self.test_file, encoding='utf8') as fp:
f.query(bytes_gen=[v.encode() for v in fp][:10])
with flow.build(backend=backend) as f, open(self.test_file, encoding='utf8') as fp:
f.query(bytes_gen=[v.encode() for v in fp][:3])

@unittest.SkipTest
# @unittest.SkipTest
def test_index_query_flow(self):
self._test_index_flow()
print('indexing finished')
self._test_query_flow()
self._test_index_flow('thread')
self._test_query_flow('thread')

def test_indexe_query_flow_proc(self):
self._test_index_flow('process')
self._test_query_flow('process')
9 changes: 1 addition & 8 deletions tests/yaml/flow-transformer.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
!PipelineEncoder
components:
- !PyTorchTransformers
parameters:
model_dir: $TORCH_TRANSFORMERS_MODEL
model_name: bert-base-uncased
- !PoolingEncoder
parameters:
pooling_strategy: REDUCE_MEAN
backend: torch
- !CharEmbeddingEncoder {}
gnes_config:
name: my_transformer # a customized name
is_trained: true # indicate the model has been trained
Expand Down

0 comments on commit f7e7791

Please sign in to comment.