From a3da05829c2756c42f45a41572a9a0f2217d9d6a Mon Sep 17 00:00:00 2001
From: hanhxiao
Date: Fri, 11 Oct 2019 10:34:26 +0800
Subject: [PATCH 1/4] fix(flow): fix flow unit test

---
 tests/test_gnes_flow.py         | 2 +-
 tests/yaml/flow-transformer.yml | 9 +--------
 2 files changed, 2 insertions(+), 9 deletions(-)

diff --git a/tests/test_gnes_flow.py b/tests/test_gnes_flow.py
index a5ec5f3f..b4c68711 100644
--- a/tests/test_gnes_flow.py
+++ b/tests/test_gnes_flow.py
@@ -144,7 +144,7 @@ def _test_query_flow(self):
         with flow.build(backend='process') as f, open(self.test_file, encoding='utf8') as fp:
             f.query(bytes_gen=[v.encode() for v in fp][:10])

-    @unittest.SkipTest
+    # @unittest.SkipTest
     def test_index_query_flow(self):
         self._test_index_flow()
         print('indexing finished')
diff --git a/tests/yaml/flow-transformer.yml b/tests/yaml/flow-transformer.yml
index f32b1d56..ba5b17f9 100644
--- a/tests/yaml/flow-transformer.yml
+++ b/tests/yaml/flow-transformer.yml
@@ -1,13 +1,6 @@
 !PipelineEncoder
 components:
-  - !PyTorchTransformers
-    parameters:
-      model_dir: $TORCH_TRANSFORMERS_MODEL
-      model_name: bert-base-uncased
-  - !PoolingEncoder
-    parameters:
-      pooling_strategy: REDUCE_MEAN
-      backend: torch
+  - !CharEmbeddingEncoder {}
 gnes_config:
   name: my_transformer # a customized name
   is_trained: true # indicate the model has been trained

From 51581bf5ff08aad7b9ef768c44c7d64c0887a772 Mon Sep 17 00:00:00 2001
From: hanhxiao
Date: Fri, 11 Oct 2019 11:15:56 +0800
Subject: [PATCH 2/4] fix(service): make service handler thread-safe

---
 gnes/service/base.py    | 15 ++++++++-------
 gnes/service/indexer.py |  2 +-
 tests/test_gnes_flow.py | 19 +++++++++++--------
 3 files changed, 20 insertions(+), 16 deletions(-)

diff --git a/gnes/service/base.py b/gnes/service/base.py
index 5f80f8ec..eb30ddb3 100644
--- a/gnes/service/base.py
+++ b/gnes/service/base.py
@@ -170,9 +170,9 @@ def register(self, msg_type: Union[List, Tuple, type]):
         def decorator(f):
             if isinstance(msg_type, list) or isinstance(msg_type, tuple):
                 for m in msg_type:
-                    self.routes[m] = f
+                    self.routes[m] = f.__name__
             else:
-                self.routes[msg_type] = f
+                self.routes[msg_type] = f.__name__
             return f

         return decorator
@@ -187,11 +187,12 @@ def register_hook(self, hook_type: Union[str, Tuple[str]], only_when_verbose: bo
         def decorator(f):
             if isinstance(hook_type, str) and hook_type in self.hooks:
-                self.hooks[hook_type].append((f, only_when_verbose))
+                self.hooks[hook_type].append((f.__name__, only_when_verbose))
+                return f
             elif isinstance(hook_type, list) or isinstance(hook_type, tuple):
                 for h in set(hook_type):
                     if h in self.hooks:
-                        self.hooks[h].append((f, only_when_verbose))
+                        self.hooks[h].append((f.__name__, only_when_verbose))
                     else:
                         raise AttributeError('hook type: %s is not supported' % h)
             return f
         return decorator
@@ -222,7 +223,7 @@ def call_hooks(self, msg: 'gnes_pb2.Message', hook_type: Union[str, Tuple[str]],
         for fn, only_verbose in hooks:
             if (only_verbose and self.service_context.args.verbose) or (not only_verbose):
                 try:
-                    fn(self.service_context, msg, *args, **kwargs)
+                    fn(msg, *args, **kwargs)
                 except Exception as ex:
                     self.logger.warning('hook %s throws an exception, '
                                         'this wont affect the server but you may want to pay attention' % fn)
@@ -249,7 +250,7 @@ def get_default_fn(m_type):

         fn = get_default_fn(type(msg))
         self.logger.info('handling message with %s' % fn.__name__)
-        return fn(self.service_context, msg)
+        return fn(msg)

     def call_routes_send_back(self, msg: 'gnes_pb2.Message', out_sock):
         try:
@@ -334,7 +335,7 @@ def __init__(self, args):
             check_version=self.args.check_version,
             timeout=self.args.timeout,
             squeeze_pb=self.args.squeeze_pb)
-        # self._override_handler()
+        self._override_handler()

     def _override_handler(self):
         # replace the function name by the function itself
diff --git a/gnes/service/indexer.py b/gnes/service/indexer.py
index 33037f88..f20edfc1 100644
--- a/gnes/service/indexer.py
+++ b/gnes/service/indexer.py
@@ -33,7 +33,7 @@ def post_init(self):
     def _handler_index(self, msg: 'gnes_pb2.Message'):
         # print('tid: %s, model: %r, self._tmp_a: %r' % (threading.get_ident(), self._model, self._tmp_a))
         # if self._tmp_a != threading.get_ident():
-        #     print('tid: %s, tmp_a: %r !!! %r' % (threading.get_ident(), self._tmp_a, self._handler_index))
+        #     print('!!! tid: %s, tmp_a: %r %r' % (threading.get_ident(), self._tmp_a, self._handler_index))
         from ..indexer.base import BaseChunkIndexer, BaseDocIndexer
         if isinstance(self._model, BaseChunkIndexer):
             self._handler_chunk_index(msg)
diff --git a/tests/test_gnes_flow.py b/tests/test_gnes_flow.py
index b4c68711..10b2429b 100644
--- a/tests/test_gnes_flow.py
+++ b/tests/test_gnes_flow.py
@@ -114,7 +114,7 @@ def test_flow5(self):
         print(f._service_edges)
         print(f.to_mermaid())

-    def _test_index_flow(self):
+    def _test_index_flow(self, backend):
         for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin]:
             self.assertFalse(os.path.exists(k))

@@ -127,13 +127,13 @@ def _test_index_flow(self):
                 .add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter',
                      num_part=2, service_in=['vec_idx', 'doc_idx']))

-        with flow.build(backend='process') as f:
+        with flow.build(backend=backend) as f:
             f.index(txt_file=self.test_file, batch_size=20)

         for k in [self.indexer1_bin, self.indexer2_bin]:
             self.assertTrue(os.path.exists(k))

-    def _test_query_flow(self):
+    def _test_query_flow(self, backend):
         flow = (Flow(check_version=False, route_table=False)
                 .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor')
                 .add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'))
@@ -141,11 +141,14 @@ def _test_query_flow(self):
                 .add(gfs.Router, name='scorer', yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml'))
                 .add(gfs.Indexer, name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml')))

-        with flow.build(backend='process') as f, open(self.test_file, encoding='utf8') as fp:
-            f.query(bytes_gen=[v.encode() for v in fp][:10])
+        with flow.build(backend=backend) as f, open(self.test_file, encoding='utf8') as fp:
+            f.query(bytes_gen=[v.encode() for v in fp][:3])

     # @unittest.SkipTest
     def test_index_query_flow(self):
-        self._test_index_flow()
-        print('indexing finished')
-        self._test_query_flow()
+        self._test_index_flow('thread')
+        self._test_query_flow('thread')
+
+    def test_indexe_query_flow_proc(self):
+        self._test_index_flow('process')
+        self._test_query_flow('process')

From c880c9b0bfe7173b61ae6489669fea202f7200d0 Mon Sep 17 00:00:00 2001
From: hanhxiao
Date: Fri, 11 Oct 2019 11:26:41 +0800
Subject: [PATCH 3/4] fix(service): make service handler thread-safe

---
 gnes/cli/parser.py      |  3 +++
 gnes/service/indexer.py | 30 ++++++++++++++++++++----------
 2 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/gnes/cli/parser.py b/gnes/cli/parser.py
index 342206ed..14bb238f 100644
--- a/gnes/cli/parser.py
+++ b/gnes/cli/parser.py
@@ -288,6 +288,9 @@ def set_indexer_parser(parser=None):
     if not parser:
         parser = set_base_parser()
     _set_sortable_service_parser(parser)
+    parser.add_argument('--as_response', type=ActionNoYes, default=True,
+                        help='convert the message type from request to response after indexing. '
+                             'turn it off if you want to chain other services after this index service.')

     return parser

diff --git a/gnes/service/indexer.py b/gnes/service/indexer.py
index f20edfc1..e9e9f7e9 100644
--- a/gnes/service/indexer.py
+++ b/gnes/service/indexer.py
@@ -36,16 +36,20 @@ def _handler_index(self, msg: 'gnes_pb2.Message'):
         #     print('!!! tid: %s, tmp_a: %r %r' % (threading.get_ident(), self._tmp_a, self._handler_index))
         from ..indexer.base import BaseChunkIndexer, BaseDocIndexer
         if isinstance(self._model, BaseChunkIndexer):
-            self._handler_chunk_index(msg)
+            is_changed = self._handler_chunk_index(msg)
         elif isinstance(self._model, BaseDocIndexer):
-            self._handler_doc_index(msg)
+            is_changed = self._handler_doc_index(msg)
         else:
             raise ServiceError(
                 'unsupported indexer, dont know how to use %s to handle this message' % self._model.__bases__)

-        msg.response.index.status = gnes_pb2.Response.SUCCESS
-        self.is_model_changed.set()
-    def _handler_chunk_index(self, msg: 'gnes_pb2.Message'):
+        if self.args.as_response:
+            msg.response.index.status = gnes_pb2.Response.SUCCESS
+
+        if is_changed:
+            self.is_model_changed.set()
+
+    def _handler_chunk_index(self, msg: 'gnes_pb2.Message') -> bool:
         embed_info = []

         for d in msg.request.index.docs:
@@ -59,13 +63,19 @@ def _handler_chunk_index(self, msg: 'gnes_pb2.Message'):
         if embed_info:
             vecs, doc_ids, offsets, weights = zip(*embed_info)
             self._model.add(list(zip(doc_ids, offsets)), np.stack(vecs), weights)
+            return True
         else:
             self.logger.warning('chunks contain no embedded vectors, the indexer will do nothing')
-
-    def _handler_doc_index(self, msg: 'gnes_pb2.Message'):
-        self._model.add([d.doc_id for d in msg.request.index.docs],
-                        [d for d in msg.request.index.docs],
-                        [d.weight for d in msg.request.index.docs])
+            return False
+
+    def _handler_doc_index(self, msg: 'gnes_pb2.Message') -> bool:
+        if msg.request.index.docs:
+            self._model.add([d.doc_id for d in msg.request.index.docs],
+                            [d for d in msg.request.index.docs],
+                            [d.weight for d in msg.request.index.docs])
+            return True
+        else:
+            return False

     def _put_result_into_message(self, results, msg: 'gnes_pb2.Message'):
         msg.response.search.ClearField('topk_results')

From 552fcdfe9ffe627d134221000f6f59c6196e14a9 Mon Sep 17 00:00:00 2001
From: hanhxiao
Date: Fri, 11 Oct 2019 11:44:13 +0800
Subject: [PATCH 4/4] feat(indexer-cli): add as_response switcher to indexer cli

---
 gnes/base/__init__.py |  4 ++--
 gnes/flow/__init__.py | 10 +++++++---
 gnes/indexer/base.py  |  5 +++--
 3 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/gnes/base/__init__.py b/gnes/base/__init__.py
index 3c253294..617a7d7e 100644
--- a/gnes/base/__init__.py
+++ b/gnes/base/__init__.py
@@ -247,7 +247,7 @@ def train(self, *args, **kwargs):
     def dump(self, filename: str = None) -> None:
         """
         Serialize the object to a binary file
-        :param filename: file path of the serialized file, if not given then `self.dump_full_path` is used
+        :param filename: file path of the serialized file, if not given then :py:attr:`dump_full_path` is used
         """
         f = filename or self.dump_full_path
         if not f:
@@ -260,7 +260,7 @@ def dump_yaml(self, filename: str = None) -> None:
         """
         Serialize the object to a yaml file

-        :param filename: file path of the yaml file, if not given then `self.dump_yaml_path` is used
+        :param filename: file path of the yaml file, if not given then :py:attr:`dump_yaml_path` is used
         """
         f = filename or self.yaml_full_path
         if not f:
diff --git a/gnes/flow/__init__.py b/gnes/flow/__init__.py
index 27b3d480..057484d1 100644
--- a/gnes/flow/__init__.py
+++ b/gnes/flow/__init__.py
@@ -61,12 +61,14 @@ class Flow:
     """
     GNES Flow: an intuitive way to build workflow for GNES.

-    You can use `.add()` then `.build()` to customize your own workflow.
+    You can use :py:meth:`.add()` then :py:meth:`.build()` to customize your own workflow.
     For example:

     .. highlight:: python
     .. code-block:: python

+        from gnes.flow import Flow, Service as gfs
+
         f = (Flow(check_version=False, route_table=True)
              .add(gfs.Preprocessor, yaml_path='BasePreprocessor')
              .add(gfs.Encoder, yaml_path='BaseEncoder')
@@ -76,9 +78,11 @@ class Flow:
         flow.index()
         ...

+    You can also use the shortcuts, e.g. :py:meth:add_encoder , :py:meth:add_preprocessor
+
     It is recommend to use flow in the context manner as showed above.
-    Note the different default copy behaviors in `.add()` and `.build()`:
-    `.add()` always copy the flow by default, whereas `.build()` modify the flow in place.
+    Note the different default copy behaviors in :py:meth:`.add()` and :py:meth:`.build()`:
+    :py:meth:`.add()` always copy the flow by default, whereas :py:meth:`.build()` modify the flow in place.
     You can change this behavior by giving an argument `copy_flow=False`.
     """

diff --git a/gnes/indexer/base.py b/gnes/indexer/base.py
index 332326e1..9748680a 100644
--- a/gnes/indexer/base.py
+++ b/gnes/indexer/base.py
@@ -12,9 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import defaultdict
 from functools import wraps
 from typing import List, Any, Union, Callable, Tuple
-from collections import defaultdict

 import numpy as np
@@ -30,7 +30,8 @@ def __init__(self,
                  is_big_score_similar: bool = False,
                  *args, **kwargs):
         """
-        Base indexer, a valid indexer must implement `add` and `query` methods
+        Base indexer, a valid indexer must implement :py:meth:`add` and :py:meth:`query` methods
+
         :type score_fn: advanced score function
         :type normalize_fn: normalizing score function
         :type is_big_score_similar: when set to true, then larger score means more similar
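
Note on the thread-safety change in [PATCH 2/4]: the register/register_hook decorators now store the handler's name (f.__name__) instead of the function object, and each service instance resolves that name against itself (see _override_handler and the fn(msg) call sites above) before the handler is invoked. The snippet below is a minimal, self-contained sketch of that registration-by-name pattern; the names _HandlerRegistry, _DemoService and dispatch are invented for illustration and are not part of the GNES code base.

    class _HandlerRegistry:
        # class-level route table that stores handler *names*, not functions
        def __init__(self):
            self.routes = {}  # message type -> handler name (a plain string)

        def register(self, msg_type):
            def decorator(f):
                # keep only the name, so the shared table never pins a
                # handler to one particular instance
                self.routes[msg_type] = f.__name__
                return f
            return decorator


    class _IndexRequest:
        pass


    class _DemoService:
        handler = _HandlerRegistry()  # shared by every instance of the service

        @handler.register(_IndexRequest)
        def _handler_index(self, msg):
            return 'instance %x handled %s' % (id(self), type(msg).__name__)

        def dispatch(self, msg):
            # resolve the stored name on *this* instance -> a bound method,
            # so each thread/process-local service touches only its own state
            fn = getattr(self, self.handler.routes[type(msg)])
            return fn(msg)


    a, b = _DemoService(), _DemoService()
    assert a.dispatch(_IndexRequest()) != b.dispatch(_IndexRequest())

Because the table is instance-agnostic, call sites no longer need to pass the service context explicitly, which is what the change from fn(self.service_context, msg) to fn(msg) above reflects.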
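
[PATCH 3/4] additionally makes the indexer handlers report whether the model was actually modified, and gates the request-to-response conversion behind the new --as_response flag. A condensed, hypothetical illustration of that control flow (toy names, not the actual GNES service code) is:

    class _ToyDocIndexerService:
        def __init__(self, as_response=True):
            self.as_response = as_response
            self.docs = []
            self.model_changed = False  # stands in for is_model_changed.set()

        def _handler_doc_index(self, docs) -> bool:
            # return True only when something was actually added
            if docs:
                self.docs.extend(docs)
                return True
            return False

        def handle(self, docs):
            is_changed = self._handler_doc_index(docs)
            if self.as_response:
                status = 'SUCCESS'  # convert the request into a response
            else:
                status = None       # keep the message a request so another
                                    # service can be chained after this one
            if is_changed:
                self.model_changed = True  # only now schedule a model dump
            return status


    svc = _ToyDocIndexerService(as_response=False)
    assert svc.handle(['doc1', 'doc2']) is None and svc.model_changed
    assert svc.handle([]) is None

This mirrors why an empty index request no longer triggers a dump and why chained indexers can leave the message type untouched.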