Merge pull request #488 from OpenNMT/v2
V2
guillaumekln authored Oct 1, 2019
2 parents 5b1e521 + 3ab65ec commit 7dbca56
Showing 237 changed files with 7,428 additions and 12,812 deletions.
27 changes: 7 additions & 20 deletions .travis.yml
@@ -1,34 +1,21 @@
language: python
python:
- "2.7"
env:
global:
- LATEST_TF_VERSION="1.13.*"
matrix:
- TF_VERSION="1.8.*"
- TF_VERSION="1.9.*"
- TF_VERSION="1.10.*"
- TF_VERSION="1.11.*"
- TF_VERSION="1.12.*"
- TF_VERSION="$LATEST_TF_VERSION"
- "3.5"
before_install:
- pip install tensorflow==$TF_VERSION
- pip install -e .[tests]
before_script:
- wget https://s3.amazonaws.com/opennmt-models/transliteration-aren.tar.gz
- wget https://s3.amazonaws.com/opennmt-models/transliteration-aren-v2.tar.gz
- mkdir -p $TRAVIS_BUILD_DIR/testdata
- tar xf transliteration-aren.tar.gz -C $TRAVIS_BUILD_DIR/testdata
- tar xf transliteration-aren-v2.tar.gz -C $TRAVIS_BUILD_DIR/testdata
script:
- nose2
- nose2 -v
matrix:
include:
- python: "3.6"
env:
- TF_VERSION="$LATEST_TF_VERSION"
install:
- pip install pylint==2.2.* wheel twine
- pip install pylint==2.3.* wheel twine
script:
- nose2
- nose2 -v
- pylint opennmt/
after_success:
- |
@@ -38,12 +25,12 @@ matrix:
fi
- python: "3.6"
env:
- TF_VERSION="$LATEST_TF_VERSION"
- secure: "YoOVLtg7j03Il5IixovqClvANqQgO/D4idZ2JANmmj+Sea4iVTEIpd4WWLcis99yctY1Kh8aAeGKSFtHUB52+5N2ignewfAdHc3yPjsk8EngAtnhZfc9J40Inky/ogZSVRm4Wh+/3+zvN3VYKVB6hmOgfz9sBxfwBwOOLLliANnyHLo2j0a4GDhkIDy6IDGnl8yO8vMG6W83KVbmyS36Y6zNmMi26/uWTWMISrErrWc5BOAqWVsFvy5lZJDXQUbQ+nSBhprNWCVkZFtWd35EP3f6iVzaR09EfOwJTGKg1TmzmJr1OydNK/hrr/YVlb3byvDFmPPPz7dBu/yXmc26j/N6O5N8BQLqGv2JBNYpy8v4zV77uGjH+W0kXKHz4QAYuQHU40TjXFVa3ZEoWDNYM6T2lvZ+mAfoj9JgsAFQ9XZhmKH6al+5HnDzhohqxTVef32dIDY+mffNb0jofexVCu+ko1wCTjUh4KyzQKSncc4Dq4welYMbfIsLMaeyHdtz5hcJhUlseVbOu6rmLtaBVuYJ350paN4/rWD9svQ+ek53XrSUpiRRLO1VU7ErruralpG/DPyIZIzWKREqGoz1eBoOxjq+iytZqiLbeMwYP3BYmg+RvMzW6m1lpkAeAMMqsEYVj/IdTnTY7jxR1yaV0A7KpuON9pBTSjKy5Ci7cCE="
install:
- pip install -r docs/requirements.txt
- pip install git+git://github.com/drdoctr/doctr@3288a31b66522312a1c541486613b14389dc73f0
script:
- set -e
- python docs/generate-apidoc.py docs/package
- sphinx-build docs docs/build
- doctr deploy --build-tags --branch-whitelist --built-docs docs/build .
51 changes: 49 additions & 2 deletions CHANGELOG.md
@@ -5,17 +5,64 @@ OpenNMT-tf follows [semantic versioning 2.0.0](https://semver.org/). The API cov
* command line options
* configuration files
* checkpoints of non experimental models
* public classes and functions that do not come from third parties
* minimum required TensorFlow version
* classes and functions documented on the online documentation and directly accessible from the top-level `opennmt` package

---

## [Unreleased]

OpenNMT-tf 2.0 is the first major update of the project. The goal of this release is to use the new features and practices introduced by TensorFlow 2.0.

### Breaking changes

See the [2.0 Transition Guide](docs/v2_transition.md) for details about the following changes.

* TensorFlow 2.0 is now required
* Python 3.5 or greater is now required
* Checkpoints are no longer compatible as the code now uses object-based instead of name-based checkpointing (except Transformer checkpoints which are automatically upgraded when loaded)
* The `onmt-main` script now uses subparsers, which requires moving the run type and its specific options to the end of the command
* Some predefined models have been renamed or changed, see the [transition guide](docs/v2_transition.md#changed-predefined-models)
* Some parameters in the YAML configuration have been renamed or changed, see the [transition guide](docs/v2_transition.md#changed-parameters)
* A lot of public classes and functions have changed, see the [API documentation](http://opennmt.net/OpenNMT-tf/package/opennmt.html) for details
* TFRecord files generated with the `opennmt.inputters.write_sequence_record` function or the `onmt-ark-to-records` script are no longer compatible and should be re-generated

This version also changes the public API scope of the project:

* Only public symbols accessible from the top-level `opennmt` package and visible on the online documentation are now part of the public API and covered by backward compatibility guarantees
* The minimum required TensorFlow version is no longer part of the public API and can change in future minor versions

### New features

* Object-based layers extending `tf.keras.layers.Layer`
* Many new reusable modules and layers, see the [API documentation](http://opennmt.net/OpenNMT-tf/package/opennmt.html)
* Replace `tf.estimator` by custom loops for more control and clearer execution path
* Multi-GPU training with `tf.distribute`
* Support early stopping based on any evaluation metric
* Support GZIP compressed datasets
* `eval` run type accepts `--features_file` and `--labels_file` to evaluate files other than the ones defined in the YAML configuration
* Accept `with_alignments: soft` to output soft alignments during inference or scoring
* `dropout` can be configured in the YAML configuration to override the model values
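
As an illustration of the early stopping item above, here is a minimal sketch of the usual stopping rule in plain Python. The function name and the `min_improvement`, `steps`, and `higher_is_better` parameters are assumptions made for this example and are not taken from the OpenNMT-tf code:

```python
def should_stop(history, min_improvement=0.01, steps=3, higher_is_better=False):
    """Returns True when none of the last `steps` evaluations improved on the
    value recorded just before them by more than `min_improvement`."""
    if len(history) < steps + 1:
        return False  # Not enough evaluations to decide yet.
    reference = history[-steps - 1]
    for value in history[-steps:]:
        gain = value - reference if higher_is_better else reference - value
        if gain > min_improvement:
            return False
    return True


# Hypothetical evaluation losses, one per evaluation run.
losses = [2.0, 1.5, 1.495, 1.498, 1.492]
print(should_stop(losses))
```

Because the rule only compares numbers, the same check applies to a loss (lower is better) or to a score such as BLEU (set `higher_is_better=True`).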

### Fixes and improvements

* Code and design simplification following TensorFlow 2.0 changes
* Improve logging during training
* Log level configuration also controls TensorFlow C++ logs
* All public classes and functions are now properly accessible from the root package `opennmt`
* Fix dtype error after updating the vocabulary of an averaged checkpoint
* When updating vocabularies, weights of new words are randomly initialized instead of zero initialized

### Missing features

Some features available in OpenNMT-tf v1 were removed or are temporarily missing in this v2 release. If you relied on some of them, please open an issue to track future support or find workarounds.

* Asynchronous distributed training
* Horovod integration
* Adafactor optimizer
* Global parameter initialization strategy (a Glorot/Xavier uniform initialization is used by default)
* Automatic SavedModel export on evaluation
* Average attention network

## [1.25.1](https://github.com/OpenNMT/OpenNMT-tf/releases/tag/v1.25.1) (2019-09-25)

### Fixes and improvements
8 changes: 5 additions & 3 deletions CONTRIBUTING.md
@@ -12,17 +12,19 @@ We use GitHub issues for bugs in the code that are **reproducible**. A good bug
* **check if the issue has been fixed** in a more recent version;
* **isolate the problem** to give as much context as possible.

If you have questions on how to use the project or have trouble getting started with it, consider using [our forum](http://forum.opennmt.net/) instead and tagging your topic with the *tensorflow* tag.
If you have questions on how to use the project or have trouble getting started with it, consider using [our forum](http://forum.opennmt.net/) instead and tagging your topic with the *opennmt-tf* tag.

## Requesting features

Do you think a feature is missing or would be a great addition to the project? Please open a GitHub issue to describe it.

## Developing features
## Developing code

*You want to share some code, that's great!*

* If you want to contribute with code but are unsure what to do, look for GitHub issues marked with the *contributions welcome* label. These are developments that we find particularly suited for community contributions.
* If you want to contribute with code but are unsure what to do,
  * search for *TODO* comments in the code: these are small dev tasks that should be addressed at some point
  * look for GitHub issues marked with the *help wanted* label: these are developments that we find particularly suited for community contributions.
* If you are planning to make a large change to the existing code, consider asking first on [the forum](http://forum.opennmt.net/) or [Gitter](https://gitter.im/OpenNMT/OpenNMT-tf) to confirm that it is welcome.

In any case, your new code **must**:
129 changes: 83 additions & 46 deletions README.md
@@ -2,7 +2,7 @@

# OpenNMT-tf

OpenNMT-tf is a general purpose sequence learning toolkit using TensorFlow. While neural machine translation is the main target task, it has been designed to more generally support:
OpenNMT-tf is a general purpose sequence learning toolkit using TensorFlow 2.0. While neural machine translation is the main target task, it has been designed to more generally support:

* sequence to sequence mapping
* sequence tagging
@@ -13,37 +13,86 @@ The project is production-oriented and comes with [backward compatibility guaran

## Key features

OpenNMT-tf focuses on modularity to support advanced modeling and training capabilities:
### Modular model architecture

* **arbitrarily complex encoder architectures**<br/>e.g. mixing RNNs, CNNs, self-attention, etc. in parallel or in sequence.
* **hybrid encoder-decoder models**<br/>e.g. self-attention encoder and RNN decoder or vice versa.
* **neural source-target alignment**<br/>train with guided alignment to constrain attention vectors and output alignments as part of the translation API.
* **multi-source training**<br/>e.g. source text and Moses translation as inputs for machine translation.
* **multiple input format**<br/>text with support of mixed word/character embeddings or real vectors serialized in *TFRecord* files.
* **on-the-fly tokenization**<br/>apply advanced tokenization dynamically during the training and detokenize the predictions during inference or evaluation.
* **domain adaptation**<br/>specialize a model to a new domain in a few training steps by updating the word vocabularies in checkpoints.
* **automatic evaluation**<br/>support for saving evaluation predictions and running external evaluators (e.g. BLEU).
* **mixed precision training**<br/>take advantage of the latest NVIDIA optimizations to train models with half-precision floating points.
Models are described with code to allow training custom architectures. For example, the following instance defines a sequence to sequence model with 2 concatenated input features, a self-attentional encoder, and an attentional RNN decoder sharing its input and output embeddings:

and all of the above can be used simultaneously to train novel and complex architectures. See the [predefined models](opennmt/models/catalog.py) to discover how they are defined and the [API documentation](http://opennmt.net/OpenNMT-tf/package/opennmt.html) to customize them.
```python
# `tfa` refers to the TensorFlow Addons package.
import tensorflow_addons as tfa
import opennmt

opennmt.models.SequenceToSequence(
    source_inputter=opennmt.inputters.ParallelInputter(
        [opennmt.inputters.WordEmbedder(embedding_size=256),
         opennmt.inputters.WordEmbedder(embedding_size=256)],
        reducer=opennmt.layers.ConcatReducer(axis=-1)),
    target_inputter=opennmt.inputters.WordEmbedder(embedding_size=512),
    encoder=opennmt.encoders.SelfAttentionEncoder(num_layers=6),
    decoder=opennmt.decoders.AttentionalRNNDecoder(
        num_layers=4,
        num_units=512,
        attention_mechanism_class=tfa.seq2seq.LuongAttention),
    share_embeddings=opennmt.models.EmbeddingsSharingLevel.TARGET)
```

The [`opennmt` package](http://opennmt.net/OpenNMT-tf/package/opennmt.html) exposes other building blocks that can be used to design:

* [multiple input features](http://opennmt.net/OpenNMT-tf/package/opennmt.inputters.ParallelInputter.html)
* [mixed embedding representation](http://opennmt.net/OpenNMT-tf/package/opennmt.inputters.MixedInputter.html)
* [multi-source context](http://opennmt.net/OpenNMT-tf/package/opennmt.inputters.ParallelInputter.html)
* [cascaded](http://opennmt.net/OpenNMT-tf/package/opennmt.encoders.SequentialEncoder.html) or [multi-column](http://opennmt.net/OpenNMT-tf/package/opennmt.encoders.ParallelEncoder.html) encoder
* [hybrid sequence to sequence models](http://opennmt.net/OpenNMT-tf/package/opennmt.models.SequenceToSequence.html)

Standard models such as the Transformer are defined in a [model catalog](opennmt/models/catalog.py) and can be used without additional configuration.

*Find more information about model configuration in the [documentation](http://opennmt.net/OpenNMT-tf/model.html).*

### Full TensorFlow 2.0 integration

OpenNMT-tf is fully integrated in the TensorFlow 2.0 ecosystem:

* Reusable layers extending [`tf.keras.layers.Layer`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/layers/Layer)
* Multi-GPU training with [`tf.distribute`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/distribute)
* Mixed precision support via a [graph optimization pass](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/train/experimental/enable_mixed_precision_graph_rewrite)
* Visualization with [TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard)
* `tf.function` graph tracing that can be [exported to a SavedModel](http://opennmt.net/OpenNMT-tf/serving.html) and served with [TensorFlow Serving](examples/serving/tensorflow_serving) or [Python](examples/serving/python)

### Dynamic data pipeline

OpenNMT-tf does not require compiling the data before training. Instead, it reads text files directly and preprocesses the data as the training needs it. This allows [on-the-fly tokenization](http://opennmt.net/OpenNMT-tf/tokenization.html) and data augmentation by injecting random noise.
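
To make the idea concrete, here is a minimal sketch of a lazy pipeline in plain Python. It is not the OpenNMT-tf implementation: the whitespace tokenizer, the word-dropout noise, and all function names are assumptions chosen for illustration.

```python
import random


def tokenize(line):
    """Placeholder tokenizer: splits on whitespace."""
    return line.split()


def drop_words(tokens, rng, dropout=0.1):
    """Randomly drops tokens but never returns an empty sequence."""
    kept = [token for token in tokens if rng.random() >= dropout]
    return kept or tokens[:1]


def stream_examples(lines, rng, dropout=0.1):
    """Lazily yields one noisy token sequence per non-empty line."""
    for line in lines:
        tokens = tokenize(line.strip())
        if tokens:
            yield drop_words(tokens, rng, dropout)


# `corpus` stands in for an open text file, which is also an iterable of lines.
corpus = ["Hello world !", "", "How are you ?"]
for example in stream_examples(corpus, random.Random(0), dropout=0.2):
    print(example)
```

Since the generator consumes one line at a time, nothing is tokenized or stored ahead of training, and each pass over the data can draw different noise.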

OpenNMT-tf is also compatible with some of the best TensorFlow features:
### Model fine-tuning

* multi-GPU training
* distributed training via [Horovod](https://github.com/uber/horovod) or TensorFlow asynchronous training
* monitoring with [TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard)
* inference with [TensorFlow Serving](https://github.com/OpenNMT/OpenNMT-tf/tree/master/examples/serving) and the [TensorFlow C++ API](https://github.com/OpenNMT/OpenNMT-tf/tree/master/examples/cpp)
OpenNMT-tf supports model fine-tuning workflows:

* Model weights can be transferred to new word vocabularies, e.g. to inject domain terminology before fine-tuning on in-domain data
* [Contrastive learning](https://ai.google/research/pubs/pub48253/) to reduce word omission errors
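
The vocabulary transfer in the first item can be sketched as follows, in plain Python on nested lists. This is an illustration under stated assumptions rather than the project's code: the function name, the uniform initialization, and its `scale` are hypothetical. Only the principle (reuse the rows of known words, randomly initialize the rows of new words) comes from the changelog above.

```python
import random


def update_embeddings(old_vocab, old_weights, new_vocab, rng, scale=0.1):
    """Builds an embedding table for `new_vocab`, reusing rows of known words."""
    index = {word: i for i, word in enumerate(old_vocab)}
    dim = len(old_weights[0])
    new_weights = []
    for word in new_vocab:
        if word in index:
            # Known word: copy its trained vector to the new position.
            new_weights.append(list(old_weights[index[word]]))
        else:
            # New word: draw a small random vector (assumed initialization).
            new_weights.append([rng.uniform(-scale, scale) for _ in range(dim)])
    return new_weights


old_vocab = ["the", "cat"]
old_weights = [[0.1, 0.2], [0.3, 0.4]]
new_vocab = ["cat", "stent", "the"]
new_weights = update_embeddings(old_vocab, old_weights, new_vocab, random.Random(0))
print(new_weights[0], new_weights[2])
```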

### Source-target alignment

Sequence to sequence models can be trained with [guided alignment](https://arxiv.org/abs/1607.01628) and alignment information is returned as part of the translation API.

---

OpenNMT-tf also implements most of the techniques commonly used to train and evaluate sequence models, such as:

* automatic evaluation during the training
* multiple decoding strategies: greedy search, beam search, random sampling
* N-best rescoring
* gradient accumulation
* scheduled sampling
* checkpoint averaging
* ... and more!
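
As one example from the list above, checkpoint averaging takes the element-wise mean of each variable over several checkpoints. The sketch below shows the arithmetic in plain Python. Representing a checkpoint as a dictionary of flat float lists is an assumption made for illustration, not the format OpenNMT-tf uses.

```python
def average_checkpoints(checkpoints):
    """Averages variables with the same name across checkpoints.

    Each checkpoint is a dict mapping a variable name to a flat list of floats.
    """
    if not checkpoints:
        raise ValueError("at least one checkpoint is required")
    averaged = {}
    for name in checkpoints[0]:
        # Group the i-th value of this variable from every checkpoint.
        columns = zip(*(checkpoint[name] for checkpoint in checkpoints))
        averaged[name] = [sum(column) / len(checkpoints) for column in columns]
    return averaged


print(average_checkpoints([{"w": [1.0, 2.0]}, {"w": [3.0, 6.0]}]))
```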

*See the [documentation](http://opennmt.net/OpenNMT-tf/) to learn how to use these features.*

## Usage

OpenNMT-tf requires:

* Python >= 2.7
* TensorFlow >= 1.4, < 2.0
* Python >= 3.5

We recommend installing it with `pip`:

```bash
pip install --upgrade pip
pip install OpenNMT-tf
```

@@ -55,59 +104,47 @@ OpenNMT-tf comes with several command line utilities to prepare data, train, and

For all tasks involving a model execution, OpenNMT-tf uses a unique entrypoint: `onmt-main`. A typical OpenNMT-tf run consists of 3 elements:

* the **run** type: `train_and_eval`, `train`, `eval`, `infer`, `export`, or `score`
* the **model** type
* the **parameters** described in a YAML file
* the **run** type such as `train`, `eval`, `infer`, `export`, `score`, `average_checkpoints`, or `update_vocab`

that are passed to the main script:

```
onmt-main <run_type> --model_type <model> --auto_config --config <config_file.yml>
onmt-main --model_type <model> --config <config_file.yml> --auto_config <run_type> <run_options>
```
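
The new argument order follows from how subparsers work: global options belong to the top-level parser, and everything after the run type is handed to that run type's own parser. The toy parser below sketches this with Python's `argparse`. It only mirrors option names mentioned in this changeset (`--model_type`, `--config`, `--auto_config`, and the `eval` options `--features_file` and `--labels_file`), and it is not the actual `onmt-main` implementation.

```python
import argparse


def build_parser():
    """Builds a toy parser mirroring the option layout described above."""
    parser = argparse.ArgumentParser(prog="onmt-main-sketch")
    # Global options are declared on the top-level parser.
    parser.add_argument("--model_type")
    parser.add_argument("--config")
    parser.add_argument("--auto_config", action="store_true")

    # Each run type is a subparser with its own options.
    subparsers = parser.add_subparsers(dest="run_type")
    subparsers.required = True
    subparsers.add_parser("train")
    eval_parser = subparsers.add_parser("eval")
    eval_parser.add_argument("--features_file")
    eval_parser.add_argument("--labels_file")
    return parser


# The file names are placeholders.
args = build_parser().parse_args([
    "--model_type", "Transformer",
    "--config", "data.yml",
    "--auto_config",
    "eval", "--features_file", "src.txt", "--labels_file", "tgt.txt",
])
print(args.run_type, args.features_file)
```

With this layout, an option such as `--features_file` is only accepted after `eval`, which is why the run type and its options move to the end of the command.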

*For more information and examples on how to use OpenNMT-tf, please visit [our documentation](http://opennmt.net/OpenNMT-tf).*

### Library

OpenNMT-tf also exposes well-defined and stable APIs. Here is an example using the library to encode a sequence using a self-attentional encoder:
OpenNMT-tf also exposes well-defined and stable APIs. Here is an example using the library to run beam search with a self-attentional decoder:

```python
import tensorflow as tf
import opennmt as onmt

tf.enable_eager_execution()
decoder = opennmt.decoders.SelfAttentionDecoder(num_layers=6)
decoder.initialize(vocab_size=32000)

# Build a random batch of input sequences.
inputs = tf.random.uniform([3, 6, 256])
sequence_length = tf.constant([4, 6, 5], dtype=tf.int32)
initial_state = decoder.initial_state(
memory=memory,
memory_sequence_length=memory_sequence_length)

# Encode with a self-attentional encoder.
encoder = onmt.encoders.SelfAttentionEncoder(num_layers=6)
outputs, _, _ = encoder.encode(
inputs,
sequence_length=sequence_length,
mode=tf.estimator.ModeKeys.TRAIN)
batch_size = tf.shape(memory)[0]
start_ids = tf.fill([batch_size], opennmt.START_OF_SENTENCE_ID)

print(outputs)
decoding_result = decoder.dynamic_decode(
target_embedding,
start_ids=start_ids,
initial_state=initial_state,
decoding_strategy=opennmt.utils.BeamSearch(4))
```

For more advanced examples, some online resources are using OpenNMT-tf as a library:

* The directory `examples/library` contains additional examples that use OpenNMT-tf as a library
* [OpenNMT Hackathon 2018](https://github.com/OpenNMT/Hackathon/tree/master/unsupervised-nmt) features a tutorial to implement unsupervised NMT using OpenNMT-tf
* [nmt-wizard-docker](https://github.com/OpenNMT/nmt-wizard-docker) uses the high-level `onmt.Runner` API to wrap OpenNMT-tf with a custom interface for training, translating, and serving

*For a complete overview of the APIs, see the [package documentation](http://opennmt.net/OpenNMT-tf/package/opennmt.html).*

## Acknowledgments

The implementation is inspired by the following:

* [TensorFlow's NMT tutorial](https://github.com/tensorflow/nmt)
* [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor)
* [Google's seq2seq](https://github.com/google/seq2seq)
* [OpenSeq2Seq](https://github.com/NVIDIA/OpenSeq2Seq)

## Additional resources

* [Documentation](http://opennmt.net/OpenNMT-tf)
35 changes: 0 additions & 35 deletions config/models/character_seq2seq.py

This file was deleted.
