diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..27e3f78c5 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/OpenNMTTokenizer"] + path = third_party/OpenNMTTokenizer + url = https://github.com/OpenNMT/Tokenizer.git diff --git a/.travis.yml b/.travis.yml index ec1257e48..25be5856f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,18 @@ language: python python: - "2.7" - "3.5" -install: +addons: + apt: + sources: + - george-edison55-precise-backports + - ubuntu-toolchain-r-test + packages: + - gcc-4.8 + - g++-4.8 + - cmake + - cmake-data + - libboost-python-dev +before_install: - pip install tensorflow==1.4.0 - pip install pyyaml - pip install nose2 @@ -14,6 +25,13 @@ install: pip install sphinx_rtd_theme pip install recommonmark fi +install: + - export CXX="g++-4.8" CC="gcc-4.8" + - mkdir build && cd build + - cmake .. + - make + - export PYTHONPATH="$PYTHONPATH:$PWD/third_party/OpenNMTTokenizer/bindings/python/" + - cd .. script: - nose2 - if [ "$TRAVIS_PYTHON_VERSION" == "3.5" ]; then pylint opennmt/ bin/; fi diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..6b165d40e --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,8 @@ +cmake_minimum_required(VERSION 3.1) + +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_BUILD_TYPE Release) +set(LIB_ONLY ON) +set(WITH_PYTHON_BINDINGS ON) + +add_subdirectory(third_party/OpenNMTTokenizer) diff --git a/README.md b/README.md index 0847d567e..7b45728d8 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ OpenNMT-tf focuses on modularity and extensibility using standard TensorFlow mod * **hybrid encoder-decoder models**
e.g. self-attention encoder and RNN decoder or vice versa. * **multi-source training**
e.g. source text and Moses translation as inputs for machine translation. * **multiple input format**
text with support of mixed word/character embeddings or real vectors serialized in *TFRecord* files. +* **on-the-fly tokenization**
apply advanced tokenization dynamically during the training and detokenize the predictions during inference or evaluation. and all of the above can be used simultaneously to train novel and complex architectures. See the [predefined models](config/models) to discover how they are defined. @@ -76,6 +77,8 @@ python -m bin.main infer --config config/opennmt-defaults.yml config/data/toy-en **Note:** do not expect any good translation results with this toy example. Consider training on [larger parallel datasets](http://www.statmt.org/wmt16/translation-task.html) instead. +*For more advanced usages, see the [documentation](http://opennmt.net/OpenNMT-tf).* + ## Compatibility with {Lua,Py}Torch implementations OpenNMT-tf has been designed from scratch and compatibility with the {Lua,Py}Torch implementations in terms of usage, design, and features is not a priority. Please submit a feature request for any missing feature or behavior that you found useful in the {Lua,Py}Torch implementations. diff --git a/bin/build_vocab.py b/bin/build_vocab.py index e0d4e4142..c92050fde 100644 --- a/bin/build_vocab.py +++ b/bin/build_vocab.py @@ -6,8 +6,6 @@ from opennmt import tokenizers from opennmt import utils -from opennmt.utils.misc import get_classnames_in_module - def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -17,9 +15,6 @@ def main(): parser.add_argument( "--save_vocab", required=True, help="Output vocabulary file.") - parser.add_argument( - "--tokenizer", default="SpaceTokenizer", choices=get_classnames_in_module(tokenizers), - help="Tokenizer class name.") parser.add_argument( "--min_frequency", type=int, default=1, help="Minimum word frequency.") @@ -29,9 +24,10 @@ def main(): parser.add_argument( "--without_sequence_tokens", default=False, action="store_true", help="If set, do not add special sequence tokens (start, end) in the vocabulary.") + tokenizers.add_command_line_arguments(parser) args = parser.parse_args() - tokenizer = getattr(tokenizers, args.tokenizer)() + tokenizer = tokenizers.build_tokenizer(args) special_tokens = [constants.PADDING_TOKEN] if not args.without_sequence_tokens: diff --git a/bin/detokenize_text.py b/bin/detokenize_text.py new file mode 100644 index 000000000..9e768adfe --- /dev/null +++ b/bin/detokenize_text.py @@ -0,0 +1,22 @@ +"""Standalone script to detokenize a corpus.""" + +from __future__ import print_function + +import argparse + +from opennmt import tokenizers + + +def main(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--delimiter", default=" ", + help="Token delimiter used in text serialization.") + tokenizers.add_command_line_arguments(parser) + args = parser.parse_args() + + tokenizer = tokenizers.build_tokenizer(args) + tokenizer.detokenize_stream(delimiter=args.delimiter) + +if __name__ == "__main__": + main() diff --git a/bin/tokenize_text.py b/bin/tokenize_text.py index 4192c2557..24dea3049 100644 --- a/bin/tokenize_text.py +++ b/bin/tokenize_text.py @@ -3,29 +3,20 @@ from __future__ import print_function import argparse -import sys from opennmt import tokenizers -from opennmt.utils.misc import get_classnames_in_module def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument( - "--tokenizer", default="SpaceTokenizer", choices=get_classnames_in_module(tokenizers), - help="Tokenizer class name.") parser.add_argument( "--delimiter", default=" ", help="Token delimiter 
for text serialization.") + tokenizers.add_command_line_arguments(parser) args = parser.parse_args() - tokenizer = getattr(tokenizers, args.tokenizer)() - - for line in sys.stdin: - line = line.strip() - tokens = tokenizer(line) - merged_tokens = args.delimiter.join(tokens) - print(merged_tokens) + tokenizer = tokenizers.build_tokenizer(args) + tokenizer.tokenize_stream(delimiter=args.delimiter) if __name__ == "__main__": main() diff --git a/config/sample.yml b/config/sample.yml index 63926bd64..69eaec5e4 100644 --- a/config/sample.yml +++ b/config/sample.yml @@ -77,7 +77,7 @@ train: # (optional) Save evaluation predictions in model_dir/eval/. save_eval_predictions: false # (optional) Evalutator or list of evaluators that are called on the saved evaluation predictions. - # Available evaluators: BLEU + # Available evaluators: BLEU, BLEU-detok external_evaluators: BLEU # (optional) The maximum length of feature sequences during training (default: None). maximum_features_length: 70 diff --git a/config/tokenization/aggressive.yml b/config/tokenization/aggressive.yml new file mode 100644 index 000000000..bb149a87f --- /dev/null +++ b/config/tokenization/aggressive.yml @@ -0,0 +1,2 @@ +mode: aggressive +joiner_annotate: true diff --git a/config/tokenization/sample.yml b/config/tokenization/sample.yml new file mode 100644 index 000000000..384d985d3 --- /dev/null +++ b/config/tokenization/sample.yml @@ -0,0 +1,12 @@ +# This is a sample tokenization configuration with all values set to their default. + +mode: conservative +bpe_model_path: "" +joiner: ■ +joiner_annotate: false +joiner_new: false +case_feature: false +segment_case: false +segment_numbers: false +segment_alphabet_change: false +segment_alphabet: [] diff --git a/docs/data.md b/docs/data.md index 8ddebe0f7..dbbaf83be 100644 --- a/docs/data.md +++ b/docs/data.md @@ -9,7 +9,7 @@ The format of the data files is defined by the `opennmt.inputters.Inputter` used All `opennmt.inputters.TextInputter`s expect a text file as input where: * sentences are separated by a **newline** -* tokens are separated by a **space** (unless a custom tokenizer is set) +* tokens are separated by a **space** (unless a custom tokenizer is set, see [Tokenization](tokenization.html)) For example: diff --git a/docs/index.rst b/docs/index.rst index 7d28237aa..bcf866783 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,6 +7,7 @@ Overview :maxdepth: 1 data.md + tokenization.md configuration.md training.md serving.md diff --git a/docs/package/opennmt.tokenizers.opennmt_tokenizer.rst b/docs/package/opennmt.tokenizers.opennmt_tokenizer.rst new file mode 100644 index 000000000..45450fd8d --- /dev/null +++ b/docs/package/opennmt.tokenizers.opennmt_tokenizer.rst @@ -0,0 +1,7 @@ +opennmt\.tokenizers\.opennmt\_tokenizer module +============================================== + +.. automodule:: opennmt.tokenizers.opennmt_tokenizer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/package/opennmt.tokenizers.rst b/docs/package/opennmt.tokenizers.rst index 4c5aa2432..1cdea2d38 100644 --- a/docs/package/opennmt.tokenizers.rst +++ b/docs/package/opennmt.tokenizers.rst @@ -11,5 +11,6 @@ Submodules .. 
toctree:: + opennmt.tokenizers.opennmt_tokenizer opennmt.tokenizers.tokenizer diff --git a/docs/tokenization.md b/docs/tokenization.md new file mode 100644 index 000000000..37b3da4af --- /dev/null +++ b/docs/tokenization.md @@ -0,0 +1,89 @@ +# Tokenization + +OpenNMT-tf can use the OpenNMT [Tokenizer](https://github.com/OpenNMT/Tokenizer) as a plugin to provide advanced tokenization behaviors. + +## Installation + +The following tools and packages are required: + +* C++11 compiler +* CMake +* Boost.Python + +On Ubuntu, these packages can be installed with `apt-get`: + +```bash +sudo apt-get install build-essential gcc cmake libboost-python-dev +``` + +1\. Fetch the Tokenizer plugin under OpenNMT-tf repository: + +```bash +git submodule update --init +``` + +2\. Compile the tokenizer plugin: + +```bash +mkdir build && cd build +cmake .. && make +cd .. +``` + +3\. Configure your environment for Python to find the newly generated package: + +```bash +export PYTHONPATH="$PYTHONPATH:$HOME/OpenNMT-tf/build/third_party/OpenNMTTokenizer/bindings/python/" +``` + +4\. Test the plugin: + +```bash +$ echo "Hello world!" | python -m bin.tokenize_text --tokenizer OpenNMTTokenizer +Hello world ! +``` + +## Usage + +YAML files are used to set the tokenizer options to ensure consistency during data preparation and training. See the sample file `config/tokenization/sample.yml`. + +Here is an example workflow: + +1\. Build the vocabularies with the custom tokenizer, e.g.: + +```bash +python -m bin.build_vocab --tokenizer OpenNMTTokenizer --tokenizer_config config/tokenization/aggressive.yml --size 50000 --save_vocab data/enfr/en-vocab.txt data/enfr/en-train.txt +python -m bin.build_vocab --tokenizer OpenNMTTokenizer --tokenizer_config config/tokenization/aggressive.yml --size 50000 --save_vocab data/enfr/fr-vocab.txt data/enfr/fr-train.txt +``` + +*The text files are only given as examples and are not part of the repository.* + +2\. Update your model's `TextInputter`s to use the custom tokenizer, e.g.: + +```python +return onmt.models.SequenceToSequence( + source_inputter=onmt.inputters.WordEmbedder( + vocabulary_file_key="source_words_vocabulary", + embedding_size=512, + tokenizer=onmt.tokenizers.OpenNMTTokenizer( + configuration_file_or_key="source_tokenizer_config")), + target_inputter=onmt.inputters.WordEmbedder( + vocabulary_file_key="target_words_vocabulary", + embedding_size=512, + tokenizer=onmt.tokenizers.OpenNMTTokenizer( + configuration_file_or_key="target_tokenizer_config")), + ...) +``` + +3\. Reference the tokenizer configurations in the data configuration, e.g.: + +```yaml +data: + source_tokenizer_config: config/tokenization/aggressive.yml + target_tokenizer_config: config/tokenization/aggressive.yml +``` + +## Notes + +* As of now, tokenizers are not part of the exported graph. +* Predictions saved during inference or evaluation are detokenized. Consider using the "BLEU-detok" external evaluator that applies a simple word level tokenization before computing the BLEU score. 
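For reference, the `pyonmttok` module built above can also be used directly from Python. The following is a minimal sketch of the tokenize/detokenize round trip that the `OpenNMTTokenizer` wrapper performs, using the options from `config/tokenization/aggressive.yml`; the exact token output is illustrative and may vary across Tokenizer versions:

```python
import pyonmttok

# Same options as config/tokenization/aggressive.yml.
tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)

# tokenize() returns (tokens, features); features are ignored here,
# as in opennmt/tokenizers/opennmt_tokenizer.py.
tokens, _ = tokenizer.tokenize("Hello world!")
# e.g. ["Hello", "world", "■!"] -- the joiner marks where "!" reattaches.

print(tokenizer.detokenize(tokens))  # Hello world!
```

Depending on the Python version and bindings build, byte strings may be required; the `OpenNMTTokenizer` wrapper converts inputs with `tf.compat.as_bytes` before calling the bindings.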
diff --git a/opennmt/inputters/text_inputter.py b/opennmt/inputters/text_inputter.py index 5e2e78bd8..a40b4c423 100644 --- a/opennmt/inputters/text_inputter.py +++ b/opennmt/inputters/text_inputter.py @@ -224,7 +224,7 @@ def _process(self, data): if "tokens" not in data: text = data["raw"] - tokens = self.tokenizer(text) + tokens = self.tokenizer.tokenize(text) length = tf.shape(tokens)[0] data = self.set_data_field(data, "tokens", tokens, padded_shape=[None], volatile=True) diff --git a/opennmt/models/sequence_to_sequence.py b/opennmt/models/sequence_to_sequence.py index 368d01b5e..a7f51fcb5 100644 --- a/opennmt/models/sequence_to_sequence.py +++ b/opennmt/models/sequence_to_sequence.py @@ -231,5 +231,5 @@ def print_prediction(self, prediction, params=None, stream=None): for i in range(n_best): tokens = prediction["tokens"][i][:prediction["length"][i] - 1] # Ignore . - sentence = b" ".join(tokens) - print_bytes(sentence, stream=stream) + sentence = self.target_inputter.tokenizer.detokenize(tokens) + print_bytes(tf.compat.as_bytes(sentence), stream=stream) diff --git a/opennmt/tests/tokenizer_test.py b/opennmt/tests/tokenizer_test.py index da3875ee6..59c74e48e 100644 --- a/opennmt/tests/tokenizer_test.py +++ b/opennmt/tests/tokenizer_test.py @@ -10,25 +10,62 @@ class TokenizerTest(tf.test.TestCase): def _testTokenizerOnTensor(self, tokenizer, text, ref_tokens): ref_tokens = [tf.compat.as_bytes(token) for token in ref_tokens] text = tf.constant(text) - tokens = tokenizer(text) + tokens = tokenizer.tokenize(text) with self.test_session() as sess: tokens = sess.run(tokens) self.assertAllEqual(ref_tokens, tokens) def _testTokenizerOnString(self, tokenizer, text, ref_tokens): ref_tokens = [tf.compat.as_text(token) for token in ref_tokens] - tokens = tokenizer(text) + tokens = tokenizer.tokenize(text) self.assertAllEqual(ref_tokens, tokens) def _testTokenizer(self, tokenizer, text, ref_tokens): self._testTokenizerOnTensor(tokenizer, text, ref_tokens) self._testTokenizerOnString(tokenizer, text, ref_tokens) + def _testDetokenizerOnTensor(self, tokenizer, tokens, ref_text): + ref_text = tf.compat.as_bytes(ref_text) + tokens = tf.constant(tokens) + text = tokenizer.detokenize(tokens) + with self.test_session() as sess: + text = sess.run(text) + self.assertEqual(ref_text, text) + + def _testDetokenizerOnBatchTensor(self, tokenizer, tokens, ref_text): + ref_text = [tf.compat.as_bytes(t) for t in ref_text] + sequence_length = [len(x) for x in tokens] + max_length = max(sequence_length) + tokens = [tok + [""] * (max_length - len(tok)) for tok in tokens] + tokens = tf.constant(tokens) + sequence_length = tf.constant(sequence_length) + text = tokenizer.detokenize(tokens, sequence_length=sequence_length) + with self.test_session() as sess: + text = sess.run(text) + self.assertAllEqual(ref_text, text) + + def _testDetokenizerOnString(self, tokenizer, tokens, ref_text): + tokens = [tf.compat.as_text(token) for token in tokens] + ref_text = tf.compat.as_text(ref_text) + text = tokenizer.detokenize(tokens) + self.assertAllEqual(ref_text, text) + + def _testDetokenizer(self, tokenizer, tokens, ref_text): + self._testDetokenizerOnBatchTensor(tokenizer, tokens, ref_text) + for tok, ref in zip(tokens, ref_text): + self._testDetokenizerOnTensor(tokenizer, tok, ref) + self._testDetokenizerOnString(tokenizer, tok, ref) + def testSpaceTokenizer(self): self._testTokenizer(SpaceTokenizer(), "Hello world !", ["Hello", "world", "!"]) + self._testDetokenizer( + SpaceTokenizer(), + [["Hello", "world", "!"], ["Test"], 
["My", "name"]], + ["Hello world !", "Test", "My name"]) def testCharacterTokenizer(self): - self._testTokenizer(CharacterTokenizer(), "a b", ["a", " ", "b"]) + self._testTokenizer(CharacterTokenizer(), "a b", ["a", "▁", "b"]) + self._testDetokenizer(CharacterTokenizer(), [["a", "▁", "b"]], ["a b"]) self._testTokenizer(CharacterTokenizer(), "你好,世界!", ["你", "好", ",", "世", "界", "!"]) diff --git a/opennmt/tokenizers/__init__.py b/opennmt/tokenizers/__init__.py index 45d0f25d5..ab98e9213 100644 --- a/opennmt/tokenizers/__init__.py +++ b/opennmt/tokenizers/__init__.py @@ -3,4 +3,33 @@ Tokenizers can work on string ``tf.Tensor`` as in-graph transformation. """ +import sys +import inspect + +try: + import pyonmttok + from opennmt.tokenizers.opennmt_tokenizer import OpenNMTTokenizer +except ImportError: + pass + from opennmt.tokenizers.tokenizer import SpaceTokenizer, CharacterTokenizer + +def add_command_line_arguments(parser): + """Adds command line arguments to select the tokenizer.""" + choices = [] + module = sys.modules[__name__] + for symbol in dir(module): + if inspect.isclass(getattr(module, symbol)): + choices.append(symbol) + + parser.add_argument( + "--tokenizer", default="SpaceTokenizer", choices=choices, + help="Tokenizer class name.") + parser.add_argument( + "--tokenizer_config", default=None, + help="Tokenization configuration file.") + +def build_tokenizer(args): + """Returns a new tokenizer based on command line arguments.""" + module = sys.modules[__name__] + return getattr(module, args.tokenizer)(configuration_file_or_key=args.tokenizer_config) diff --git a/opennmt/tokenizers/opennmt_tokenizer.py b/opennmt/tokenizers/opennmt_tokenizer.py new file mode 100644 index 000000000..5aced32d7 --- /dev/null +++ b/opennmt/tokenizers/opennmt_tokenizer.py @@ -0,0 +1,60 @@ +"""Define the OpenNMT tokenizer.""" + +import six + +import pyonmttok + +import tensorflow as tf + +from opennmt.tokenizers.tokenizer import Tokenizer + + +def create_tokenizer(config): + """Creates a new OpenNMT tokenizer. + + Args: + config: A dictionary of tokenization options. + + Returns: + A ``pyonmttok.Tokenizer``. 
+ """ + def _set(kwargs, key): + if key in config: + value = config[key] + if isinstance(value, six.string_types): + value = tf.compat.as_bytes(value) + kwargs[key] = value + + kwargs = {} + _set(kwargs, "bpe_model_path") + _set(kwargs, "joiner") + _set(kwargs, "joiner_annotate") + _set(kwargs, "joiner_new") + _set(kwargs, "case_feature") + _set(kwargs, "segment_case") + _set(kwargs, "segment_numbers") + _set(kwargs, "segment_alphabet_change") + _set(kwargs, "segment_alphabet") + + return pyonmttok.Tokenizer(config.get("mode", "conservative"), **kwargs) + + +class OpenNMTTokenizer(Tokenizer): + """Uses the OpenNMT tokenizer.""" + + def __init__(self, configuration_file_or_key=None): + super(OpenNMTTokenizer, self).__init__(configuration_file_or_key=configuration_file_or_key) + self._tokenizer = None + + def _tokenize_string(self, text): + if self._tokenizer is None: + self._tokenizer = create_tokenizer(self._config) + text = tf.compat.as_bytes(text) + tokens, _ = self._tokenizer.tokenize(text) + return tokens + + def _detokenize_string(self, tokens): + if self._tokenizer is None: + self._tokenizer = create_tokenizer(self._config) + tokens = [tf.compat.as_bytes(token) for token in tokens] + return self._tokenizer.detokenize(tokens) diff --git a/opennmt/tokenizers/tokenizer.py b/opennmt/tokenizers/tokenizer.py index ebc571bee..cb96c15b1 100644 --- a/opennmt/tokenizers/tokenizer.py +++ b/opennmt/tokenizers/tokenizer.py @@ -1,16 +1,84 @@ +# -*- coding: utf-8 -*- + """Define base tokenizers.""" +import sys +import os import abc import six +import yaml import tensorflow as tf +from opennmt.utils.misc import print_bytes + @six.add_metaclass(abc.ABCMeta) class Tokenizer(object): """Base class for tokenizers.""" - def __call__(self, text): + def __init__(self, configuration_file_or_key=None): + """Initializes the tokenizer. + + Args: + configuration_file_or_key: The YAML configuration file or a the key to + the YAML configuration file. + """ + self._config = {} + if configuration_file_or_key is not None and os.path.isfile(configuration_file_or_key): + configuration_file = configuration_file_or_key + with open(configuration_file) as conf_file: + self._config = yaml.load(conf_file) + self._configuration_file_key = None + else: + self._configuration_file_key = configuration_file_or_key + + def initialize(self, metadata): + """Initializes the tokenizer (e.g. load BPE models). + + Any external assets should be registered in the standard assets collection: + + .. code-block:: python + + tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, filename) + + Args: + metadata: A dictionary containing additional metadata set + by the user. + """ + if self._configuration_file_key is not None: + configuration_file = metadata[self._configuration_file_key] + with open(configuration_file) as conf_file: + self._config = yaml.load(conf_file) + + def tokenize_stream(self, input_stream=sys.stdin, output_stream=sys.stdout, delimiter=" "): + """Tokenizes a stream of sentences. + + Args: + input_stream: The input stream. + output_stream: The output stream. + delimiter: The token delimiter to use for text serialization. + """ + for line in input_stream: + line = line.strip() + tokens = self.tokenize(line) + merged_tokens = delimiter.join(tokens) + print_bytes(tf.compat.as_bytes(merged_tokens), stream=output_stream) + + def detokenize_stream(self, input_stream=sys.stdin, output_stream=sys.stdout, delimiter=" "): + """Detokenizes a stream of sentences. + + Args: + input_stream: The input stream. 
+ output_stream: The output stream. + delimiter: The token delimiter used for text serialization. + """ + for line in input_stream: + tokens = line.strip().split(delimiter) + string = self.detokenize(tokens) + print_bytes(tf.compat.as_bytes(string), stream=output_stream) + + def tokenize(self, text): """Tokenizes text. Args: @@ -19,27 +87,61 @@ def __call__(self, text): Returns: A 1-D string ``tf.Tensor`` if :obj:`text` is a ``tf.Tensor`` or a list of Python unicode strings otherwise. + + Raises: + ValueError: if the rank of :obj:`text` is greater than 0. """ if tf.contrib.framework.is_tensor(text): - return self._tokenize_tensor(text) + rank = len(text.get_shape().as_list()) + if rank == 0: + return self._tokenize_tensor(text) + else: + raise ValueError("Unsupported tensor rank for tokenization: {}".format(rank)) else: text = tf.compat.as_text(text) return self._tokenize_string(text) - def initialize(self, metadata): - """Initializes the tokenizer (e.g. load BPE models). + def detokenize(self, tokens, sequence_length=None): + """Detokenizes tokens. - Any external assets should be registered in the standard assets collection: + The Tensor version supports batches of tokens. - .. code-block:: python + Args: + tokens: The tokens as a 1-D or 2-D ``tf.Tensor`` or list of Python + strings. + sequence_length: The length of each sequence. Required if :obj:`tokens` + is a ``tf.Tensor``. - tf.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename) + Returns: + A 0-D or 1-D string ``tf.Tensor`` if :obj:`tokens` is a ``tf.Tensor`` or a + Python unicode strings otherwise. - Args: - metadata: A dictionary containing additional metadata set - by the user. + Raises: + ValueError: if the rank of :obj:`tokens` is greater than 2. + ValueError: if :obj:`tokens` is a 2-D ``tf.Tensor`` and + :obj:`sequence_length` is not set. """ - pass + if tf.contrib.framework.is_tensor(tokens): + rank = len(tokens.get_shape().as_list()) + if rank == 1: + return self._detokenize_tensor(tokens) + elif rank == 2: + if sequence_length is None: + raise ValueError("sequence_length is required for Tensor detokenization") + batch_size = tf.shape(tokens)[0] + array = tf.TensorArray(tf.string, size=batch_size, dynamic_size=False) + _, array = tf.while_loop( + lambda i, _: i < batch_size, + lambda i, a: ( + i + 1, a.write(i, self._detokenize_tensor(tokens[i, :sequence_length[i]]))), + (tf.constant(0), array), + back_prop=False) + return array.stack() + else: + raise ValueError("Unsupported tensor rank for detokenization: {}".format(rank)) + else: + tokens = [tf.compat.as_text(token) for token in tokens] + return self._detokenize_string(tokens) def _tokenize_tensor(self, text): """Tokenizes a tensor. @@ -54,10 +156,24 @@ def _tokenize_tensor(self, text): A 1-D string ``tf.Tensor``. """ text = tf.py_func( - lambda x: tf.compat.as_bytes("\0".join(self(x))), [text], tf.string) + lambda x: tf.compat.as_bytes("\0".join(self.tokenize(x))), [text], tf.string) tokens = tf.string_split([text], delimiter="\0").values return tokens + def _detokenize_tensor(self, tokens): + """Detokenizes tokens. + + When not overriden, this default implementation uses a ``tf.py_func`` + operation to call the string-based detokenization. + + Args: + tokens: A 1-D ``tf.Tensor``. + + Returns: + A 0-D string ``tf.Tensor``. + """ + return tf.py_func(self.detokenize, [tokens], tf.string) + @abc.abstractmethod def _tokenize_string(self, text): """Tokenizes a Python unicode string. 
@@ -72,19 +188,40 @@ def _tokenize_string(self, text): """ raise NotImplementedError() + @abc.abstractmethod + def _detokenize_string(self, tokens): + """Detokenizes tokens. + + Args: + tokens: A list of Python unicode strings. + + Returns: + A unicode Python string. + """ + raise NotImplementedError() + class SpaceTokenizer(Tokenizer): """A tokenizer that splits on spaces.""" def _tokenize_tensor(self, text): - return tf.string_split([text]).values + return tf.string_split([text], delimiter=" ").values + + def _detokenize_tensor(self, tokens): + return tf.foldl(lambda a, x: a + " " + x, tokens, back_prop=False) def _tokenize_string(self, text): return text.split() + def _detokenize_string(self, tokens): + return " ".join(tokens) + class CharacterTokenizer(Tokenizer): """A tokenizer that splits unicode characters.""" def _tokenize_string(self, text): - return list(text) + return list(text.replace(" ", u"▁")) + + def _detokenize_string(self, tokens): + return "".join(tokens).replace(u"▁", " ") diff --git a/opennmt/utils/evaluator.py b/opennmt/utils/evaluator.py index 183be6238..1d5c1f9c9 100644 --- a/opennmt/utils/evaluator.py +++ b/opennmt/utils/evaluator.py @@ -9,6 +9,13 @@ import tensorflow as tf from tensorflow.python.summary.writer.writer_cache import FileWriterCache as SummaryWriterCache +from opennmt import tokenizers + + +def _word_level_tokenization(input_filename, output_filename): + tokenizer = tokenizers.OpenNMTTokenizer() + with open(input_filename, "rb") as input_file, open(output_filename, "wb") as output_file: + tokenizer.tokenize_stream(input_stream=input_file, output_stream=output_file) @six.add_metaclass(abc.ABCMeta) @@ -82,6 +89,26 @@ def score(self, labels_file, predictions_path): return None +class BLEUDetokEvaluator(BLEUEvaluator): + """Evaluator applying a simple tokenization before calling multi-bleu.perl.""" + + def __init__(self, labels_file=None, output_dir=None): + if not hasattr(tokenizers, "OpenNMTTokenizer"): + raise RuntimeError("The BLEU-detok evaluator only works when the OpenNMT tokenizer " + "is available. 
Please re-check its installation.") + super(BLEUDetokEvaluator, self).__init__(labels_file=labels_file, output_dir=output_dir) + + def name(self): + return "BLEU-detok" + + def score(self, labels_file, predictions_path): + tok_labels_file = labels_file + ".light_tok" + tok_predictions_path = predictions_path + ".light_tok" + _word_level_tokenization(labels_file, tok_labels_file) + _word_level_tokenization(predictions_path, tok_predictions_path) + return super(BLEUDetokEvaluator, self).score(tok_labels_file, tok_predictions_path) + + def external_evaluation_fn(evaluators_name, labels_file, output_dir=None): """Returns a callable to be used in :class:`opennmt.utils.hooks.SaveEvaluationPredictionHook` that calls one or @@ -110,6 +137,8 @@ def external_evaluation_fn(evaluators_name, labels_file, output_dir=None): name = name.lower() if name == "bleu": evaluator = BLEUEvaluator(labels_file=labels_file, output_dir=output_dir) + elif name == "bleu-detok": + evaluator = BLEUDetokEvaluator(labels_file=labels_file, output_dir=output_dir) else: raise ValueError("No evaluator associated with the name: {}".format(name)) evaluators.append(evaluator) diff --git a/opennmt/utils/misc.py b/opennmt/utils/misc.py index 66c2db197..ca9e49b60 100644 --- a/opennmt/utils/misc.py +++ b/opennmt/utils/misc.py @@ -3,7 +3,6 @@ from __future__ import print_function import sys -import inspect import tensorflow as tf @@ -43,14 +42,6 @@ def count_lines(filename): pass return i + 1 -def get_classnames_in_module(module): - """Returns a list of classnames exposed by a module.""" - names = [] - for symbol in dir(module): - if inspect.isclass(getattr(module, symbol)): - names.append(symbol) - return names - def count_parameters(): """Returns the total number of trainable parameters.""" total = 0 diff --git a/opennmt/utils/vocab.py b/opennmt/utils/vocab.py index a32c4ea64..e9f20593e 100644 --- a/opennmt/utils/vocab.py +++ b/opennmt/utils/vocab.py @@ -44,7 +44,7 @@ def add_from_text(self, filename, tokenizer=None): for line in text: line = tf.compat.as_text(line.strip()) if tokenizer: - tokens = tokenizer(line) + tokens = tokenizer.tokenize(line) else: tokens = line.split() for token in tokens: diff --git a/third_party/OpenNMTTokenizer b/third_party/OpenNMTTokenizer new file mode 160000 index 000000000..60ac5efd0 --- /dev/null +++ b/third_party/OpenNMTTokenizer @@ -0,0 +1 @@ +Subproject commit 60ac5efd0db175ad2d736f1086e434680dc068bc diff --git a/third_party/learn_bpe.py b/third_party/learn_bpe.py new file mode 100644 index 000000000..2d1da3848 --- /dev/null +++ b/third_party/learn_bpe.py @@ -0,0 +1,272 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# Author: Rico Sennrich + +# The MIT License (MIT) + +# Copyright (c) 2015 University of Edinburgh + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+ +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Use byte pair encoding (BPE) to learn a variable-length encoding of the vocabulary in a text. +Unlike the original BPE, it does not compress the plain text, but can be used to reduce the vocabulary +of a text to a configurable number of symbols, with only a small increase in the number of tokens. + +Reference: +Rico Sennrich, Barry Haddow and Alexandra Birch (2016). Neural Machine Translation of Rare Words with Subword Units. +Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (ACL 2016). Berlin, Germany. +""" + +from __future__ import unicode_literals + +import sys +import codecs +import re +import copy +import argparse +from collections import defaultdict, Counter + +# hack for python2/3 compatibility +from io import open +argparse.open = open + +def create_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="learn BPE-based word segmentation") + + parser.add_argument( + '--input', '-i', type=argparse.FileType('r'), default=sys.stdin, + metavar='PATH', + help="Input text (default: standard input).") + + parser.add_argument( + '--output', '-o', type=argparse.FileType('w'), default=sys.stdout, + metavar='PATH', + help="Output file for BPE codes (default: standard output)") + parser.add_argument( + '--symbols', '-s', type=int, default=10000, + help="Create this many new symbols (each representing a character n-gram) (default: %(default)s))") + parser.add_argument( + '--min-frequency', type=int, default=2, metavar='FREQ', + help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))') + parser.add_argument('--dict-input', action="store_true", + help="If set, input file is interpreted as a dictionary where each line contains a word-count pair") + parser.add_argument( + '--verbose', '-v', action="store_true", + help="verbose mode.") + + return parser + +def get_vocabulary(fobj, is_dict=False): + """Read text and return dictionary that encodes vocabulary + """ + vocab = Counter() + for line in fobj: + if is_dict: + word, count = line.strip().split() + vocab[word] = int(count) + else: + for word in line.split(): + vocab[word] += 1 + return vocab + +def update_pair_statistics(pair, changed, stats, indices): + """Minimally update the indices and frequency of symbol pairs + + if we merge a pair of symbols, only pairs that overlap with occurrences + of this pair are affected, and need to be updated. 
+ """ + stats[pair] = 0 + indices[pair] = defaultdict(int) + first, second = pair + new_pair = first+second + for j, word, old_word, freq in changed: + + # find all instances of pair, and update frequency/indices around it + i = 0 + while True: + # find first symbol + try: + i = old_word.index(first, i) + except ValueError: + break + # if first symbol is followed by second symbol, we've found an occurrence of pair (old_word[i:i+2]) + if i < len(old_word)-1 and old_word[i+1] == second: + # assuming a symbol sequence "A B C", if "B C" is merged, reduce the frequency of "A B" + if i: + prev = old_word[i-1:i+1] + stats[prev] -= freq + indices[prev][j] -= 1 + if i < len(old_word)-2: + # assuming a symbol sequence "A B C B", if "B C" is merged, reduce the frequency of "C B". + # however, skip this if the sequence is A B C B C, because the frequency of "C B" will be reduced by the previous code block + if old_word[i+2] != first or i >= len(old_word)-3 or old_word[i+3] != second: + nex = old_word[i+1:i+3] + stats[nex] -= freq + indices[nex][j] -= 1 + i += 2 + else: + i += 1 + + i = 0 + while True: + try: + # find new pair + i = word.index(new_pair, i) + except ValueError: + break + # assuming a symbol sequence "A BC D", if "B C" is merged, increase the frequency of "A BC" + if i: + prev = word[i-1:i+1] + stats[prev] += freq + indices[prev][j] += 1 + # assuming a symbol sequence "A BC B", if "B C" is merged, increase the frequency of "BC B" + # however, if the sequence is A BC BC, skip this step because the count of "BC BC" will be incremented by the previous code block + if i < len(word)-1 and word[i+1] != new_pair: + nex = word[i:i+2] + stats[nex] += freq + indices[nex][j] += 1 + i += 1 + + +def get_pair_statistics(vocab): + """Count frequency of all symbol pairs, and create index""" + + # data structure of pair frequencies + stats = defaultdict(int) + + #index from pairs to words + indices = defaultdict(lambda: defaultdict(int)) + + for i, (word, freq) in enumerate(vocab): + prev_char = word[0] + for char in word[1:]: + stats[prev_char, char] += freq + indices[prev_char, char][i] += 1 + prev_char = char + + return stats, indices + + +def replace_pair(pair, vocab, indices): + """Replace all occurrences of a symbol pair ('A', 'B') with a new symbol 'AB'""" + first, second = pair + pair_str = ''.join(pair) + pair_str = pair_str.replace('\\','\\\\') + changes = [] + pattern = re.compile(r'(?'); + # version numbering allows bckward compatibility + outfile.write('#version: 0.2\n') + + vocab = get_vocabulary(infile, is_dict) + vocab = dict([(tuple(x[:-1])+(x[-1]+'',) ,y) for (x,y) in vocab.items()]) + sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True) + + stats, indices = get_pair_statistics(sorted_vocab) + big_stats = copy.deepcopy(stats) + # threshold is inspired by Zipfian assumption, but should only affect speed + threshold = max(stats.values()) / 10 + for i in range(num_symbols): + if stats: + most_frequent = max(stats, key=lambda x: (stats[x], x)) + + # we probably missed the best pair because of pruning; go back to full statistics + if not stats or (i and stats[most_frequent] < threshold): + prune_stats(stats, big_stats, threshold) + stats = copy.deepcopy(big_stats) + most_frequent = max(stats, key=lambda x: (stats[x], x)) + # threshold is inspired by Zipfian assumption, but should only affect speed + threshold = stats[most_frequent] * i/(i+10000.0) + prune_stats(stats, big_stats, threshold) + + if stats[most_frequent] < min_frequency: + sys.stderr.write('no pair has 
frequency >= {0}. Stopping\n'.format(min_frequency)) + break + + if verbose: + sys.stderr.write('pair {0}: {1} {2} -> {1}{2} (frequency {3})\n'.format(i, most_frequent[0], most_frequent[1], stats[most_frequent])) + outfile.write('{0} {1}\n'.format(*most_frequent)) + changes = replace_pair(most_frequent, sorted_vocab, indices) + update_pair_statistics(most_frequent, changes, stats, indices) + stats[most_frequent] = 0 + if not i % 100: + prune_stats(stats, big_stats, threshold) + + +if __name__ == '__main__': + + # python 2/3 compatibility + if sys.version_info < (3, 0): + sys.stderr = codecs.getwriter('UTF-8')(sys.stderr) + sys.stdout = codecs.getwriter('UTF-8')(sys.stdout) + sys.stdin = codecs.getreader('UTF-8')(sys.stdin) + else: + sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer) + sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer) + sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer) + + parser = create_parser() + args = parser.parse_args() + + # read/write files as UTF-8 + if args.input.name != '': + args.input = codecs.open(args.input.name, encoding='utf-8') + if args.output.name != '': + args.output = codecs.open(args.output.name, 'w', encoding='utf-8') + + main(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input)
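The `detokenize()` method added to `opennmt/tokenizers/tokenizer.py` also accepts a padded 2-D token tensor together with `sequence_length`, returning one string per row. A minimal sketch of that batch path, assuming the TensorFlow 1.x session semantics used by the tests in this patch:

```python
import tensorflow as tf

from opennmt.tokenizers.tokenizer import SpaceTokenizer

tokenizer = SpaceTokenizer()

# Padded batch of tokens; sequence_length gives the valid length of each row.
tokens = tf.constant([["Hello", "world", "!"], ["Test", "", ""]])
sequence_length = tf.constant([3, 1])

# Rank-2 input requires sequence_length and yields a 1-D string tensor.
text = tokenizer.detokenize(tokens, sequence_length=sequence_length)

with tf.Session() as sess:
  print(sess.run(text))  # prints something like [b'Hello world !' b'Test']
```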