This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Added Seq2Seq tasks #37

Merged
merged 43 commits into from
Feb 2, 2021
Changes from 24 commits
Commits
54686cc
Added Seq2Seq tasks
Feb 1, 2021
c400db3
Use rank 0 for model specific params
Feb 1, 2021
81725c6
Add licences
Feb 1, 2021
8a3e2e4
Fix summarization scripts
Feb 1, 2021
0a4c5a9
Fix comments, update from files API
Feb 1, 2021
6e66b01
Add tests
Feb 1, 2021
aa955e4
Add docs
Feb 1, 2021
14f1963
Fix doc header
Feb 1, 2021
310f3ea
Apply suggestions from code review
SeanNaren Feb 1, 2021
8d9775f
Add typing
Feb 1, 2021
abdc843
Add imports, fix docs
Feb 1, 2021
c505c86
Add rouge score for metric
Feb 1, 2021
a5c82fd
Merge branch 'master' into feature/seq2seq
SeanNaren Feb 1, 2021
4888af0
fix imports
Feb 1, 2021
35c772a
fix imports and style
Feb 1, 2021
0e9cfd8
Install sentencepiece for slow tokenizer conversion
Feb 1, 2021
46984d2
yapf
Borda Feb 1, 2021
2411de1
Fixed underlines
Feb 1, 2021
f820919
Fixed doc references
Feb 1, 2021
79cd28c
Added min versions address formatting
Feb 1, 2021
3a8e7a1
Update requirement
Feb 1, 2021
410c500
Fix formatting issues
Feb 1, 2021
f36a80c
Merge branch 'master' into feature/seq2seq
SeanNaren Feb 1, 2021
7997516
Merge branch 'master' into feature/seq2seq
SeanNaren Feb 1, 2021
e7d2a66
Merge branch 'master' into feature/seq2seq
SeanNaren Feb 1, 2021
6859c61
add seq to seq finetuning callback
tchaton Feb 1, 2021
eaede8b
Merge branch 'master' into feature/seq2seq
tchaton Feb 1, 2021
89782a0
docs: link blog
Borda Feb 2, 2021
db2e937
resolve tests
tchaton Feb 2, 2021
d038a27
Merge branch 'feature/seq2seq' of https://github.com/PyTorchLightning…
tchaton Feb 2, 2021
e69f9d2
update
tchaton Feb 2, 2021
409018c
Delete lock file
Feb 2, 2021
b64ebb8
remove download_model
tchaton Feb 2, 2021
6910a7b
Merge branch 'feature/seq2seq' of https://github.com/PyTorchLightning…
tchaton Feb 2, 2021
9d528cd
Revert some changes, update requirements.txt
Feb 2, 2021
3efb1d3
Move to mbart for now, even if it's a large model file
Feb 2, 2021
a2ed0f3
Clean up finetuning module, fix tests plus add todo
Feb 2, 2021
5462f1b
Cleanup
Feb 2, 2021
53997d7
Update flash/text/seq2seq/core/model.py
SeanNaren Feb 2, 2021
9da5051
Remove lock file, add typing
Feb 2, 2021
7d89452
Change to test code
Feb 2, 2021
fb11caf
Swap to module available
Feb 2, 2021
23aacf9
Revert testcode due to test error
Feb 2, 2021
2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -22,8 +22,10 @@ Lightning Flash
reference/task
reference/image_classification
reference/image_embedder
reference/summarization
reference/text_classification
reference/tabular_classification
reference/translation

.. toctree::
:maxdepth: 1
185 changes: 185 additions & 0 deletions docs/source/reference/summarization.rst
@@ -0,0 +1,185 @@
.. _summarization:

#############
Summarization
#############

********
The task
********

Summarization is the task of condensing a longer document or article into a short sentence or description, for example taking a web article and describing its topic in a single sentence.
This task is a subset of Sequence to Sequence tasks, which require the model to generate a variable-length sequence given an input sequence. In our case the article is the input sequence, and the short description is the output sequence generated by the model.

-----

*********
Inference
*********

The :class:`~flash.text.SummarizationTask` is pre-trained on `XSUM <https://arxiv.org/abs/1808.08745>`_, a dataset of online British Broadcasting Corporation articles.

Use the :class:`~flash.text.SummarizationTask` pretrained model for inference on any string sequence using :func:`~flash.text.SummarizationTask.predict`:

.. code-block:: python

# import our libraries
from flash.text import SummarizationTask


# 1. Load the model from a checkpoint
model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt")

# 2. Perform inference from a sequence
predictions = model.predict([
"""
Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local
people within the community. The couple were surrounded by shoppers as they walked along Electric Avenue.
They came to Brixton to see work which has started to revitalise the borough.
It was Charles' first visit to the area since 1996, when he was accompanied by the former
South African president Nelson Mandela. Greengrocer Derek Chong, who has run a stall on Electric Avenue
for 20 years, said Camilla had been "nice and pleasant" when she purchased the fruit.
"She asked me what was nice, what would I recommend, and I said we've got some nice mangoes.
She asked me were they ripe and I said yes - they're from the Dominican Republic."
Mr Chong is one of 170 local retailers who accept the Brixton Pound.
Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market
or in participating shops.
During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children
nearby on an estate off Coldharbour Lane. Mr West said:
"He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man."
He added: "I told him I was working with young kids and he said, 'Keep up all the good work.'"
Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust.
The trust hopes to restore and refurbish the building,
where once Jimi Hendrix and The Clash played, as a new community and business centre.
"""
])
print(predictions)

Or on a given dataset:

.. code-block:: python

# import our libraries
from pytorch_lightning import Trainer
from flash import download_data
from flash.text import SummarizationData, SummarizationTask

# 1. Load the model from a checkpoint
model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt")

# 2. Create dataset from file
datamodule = SummarizationData.from_file(
predict_file="data/xsum/predict.csv",
input="input",
)

# 3. Generate summaries
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)

For more advanced inference options, see :ref:`predictions`.

-----

**********
Finetuning
**********

Say you want to finetune on your own summarization data. We use the XSUM dataset as an example, which contains a ``train.csv`` and a ``valid.csv`` structured like so:

.. code-block::

input,target
"The researchers have sequenced the genome of a strain of bacterium that causes the virulent infection...","A team of UK scientists hopes to shed light on the mysteries of bleeding canker, a disease that is threatening the nation's horse chestnut trees."
"Knight was shot in the leg by an unknown gunman at Miami's Shore Club where West was holding a pre-MTV Awards...",Hip hop star Kanye West is being sued by Death Row Records founder Suge Knight over a shooting at a beach party in August 2005.
...

In the above, the ``input`` column holds the long articles/documents, and the ``target`` column holds the short description the model should learn to generate.
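
If your own data is not yet in this layout, the snippet below is a minimal sketch of how a two-column ``input,target`` file could be produced with Python's standard ``csv`` module. The rows and the ``write_pairs`` helper are made up for illustration and are not part of Flash or XSUM.

.. code-block:: python

    import csv
    import io

    # Hypothetical (article, summary) pairs, invented for this sketch.
    pairs = [
        ("A long article about a new bridge opening in the city centre ...", "New bridge opens in city centre."),
        ('The mayor said, "It is finished."', "Mayor declares project finished."),
    ]


    def write_pairs(handle, rows):
        """Write (input, target) rows under an ``input,target`` header."""
        writer = csv.writer(handle)  # quotes fields containing commas or quotes
        writer.writerow(["input", "target"])
        writer.writerows(rows)


    # An in-memory buffer keeps the sketch self-contained; for a real file use
    # open("train.csv", "w", newline="", encoding="utf-8") instead.
    buffer = io.StringIO()
    write_pairs(buffer, pairs)
    csv_text = buffer.getvalue()
    print(csv_text)

Letting the ``csv`` module handle quoting avoids hand-escaping articles that contain commas or quotation marks.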

All we need is a few lines of code to train our model!

.. code-block:: python

# import our libraries
import flash
from flash import download_data
from flash.text import SummarizationData, SummarizationTask

# 1. Download data
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/')

# 2. Organize the data
datamodule = SummarizationData.from_files(
train_file="data/xsum/train.csv",
valid_file="data/xsum/valid.csv",
test_file="data/xsum/test.csv",
input="input",
target="target"
)

# 3. Build the task
model = SummarizationTask()

# 4. Create trainer
trainer = flash.Trainer(max_epochs=1, gpus=1)

# 5. Finetune the task
trainer.finetune(model, datamodule=datamodule)

# 6. Save trainer task
trainer.save_checkpoint("summarization_model_xsum.pt")

----

To run the example:

.. code-block:: bash

python flash_examples/finetuning/summarization.py


------

*********************
Changing the backbone
*********************
By default, we use the `T5 <https://arxiv.org/abs/1910.10683>`_ model for summarization. You can change the model by passing the ``backbone`` parameter.

.. note:: When changing the backbone, make sure you pass in the same backbone to the Task and the Data object! Since this is a Seq2Seq task, make sure you use a Seq2Seq model.

.. code-block:: python

datamodule = SummarizationData.from_files(
train_file="data/xsum/train.csv",
valid_file="data/xsum/valid.csv",
test_file="data/xsum/test.csv",
input="input",
target="target",
backbone="google/mt5-small",
)

model = SummarizationTask(backbone="google/mt5-small")
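
One way to keep the Task and the Data object on the same backbone is to define the name once and derive both argument sets from it. The sketch below uses only plain Python; ``shared_backbone_kwargs`` is a hypothetical helper, not part of Flash, and the ``data/xsum`` directory is assumed from the finetuning example.

.. code-block:: python

    from typing import Any, Dict, Tuple


    def shared_backbone_kwargs(backbone: str, data_dir: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Hypothetical helper: build data and task arguments from one backbone name."""
        data_kwargs = {
            "train_file": f"{data_dir}/train.csv",
            "valid_file": f"{data_dir}/valid.csv",
            "test_file": f"{data_dir}/test.csv",
            "input": "input",
            "target": "target",
            "backbone": backbone,
        }
        task_kwargs = {"backbone": backbone}
        return data_kwargs, task_kwargs


    data_kwargs, task_kwargs = shared_backbone_kwargs("google/mt5-small", "data/xsum")

    # With Flash installed, these would be unpacked as
    # SummarizationData.from_files(**data_kwargs) and SummarizationTask(**task_kwargs).
    print(data_kwargs["backbone"], task_kwargs["backbone"])

Because both dictionaries are built from the same argument, the tokenizer used by the data module and the model used by the task cannot drift apart.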

------

*************
API reference
*************

.. _summarization_task:

SummarizationTask
-----------------

.. autoclass:: flash.text.seq2seq.summarization.model.SummarizationTask
:members:
:exclude-members: forward

.. _summarization_data:

SummarizationData
-----------------

.. autoclass:: flash.text.seq2seq.summarization.data.SummarizationData

.. automethod:: flash.text.seq2seq.summarization.data.SummarizationData.from_files
12 changes: 6 additions & 6 deletions docs/source/reference/text_classification.rst
@@ -16,9 +16,9 @@ Text classification is the task of assigning a piece of text (word, sentence or
Inference
*********

The :class:`~flash.text.TextClassifier` is pre-trained on `IMDB <https://www.kaggle.com/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews>`_, a dataset of highly polarized movie reviews, for binary classification: predicting whether a given review has a positive or negative sentiment.

Use the :class:`~flash.text.TextClassifier` pretrained model for inference on any string sequence using :func:`~flash.text.TextClassifier.predict`:

.. code-block:: python

@@ -83,10 +83,10 @@ All we need is three lines of code to train our model!

.. code-block:: python

# import our libraries
import flash
from flash import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Download data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
167 changes: 167 additions & 0 deletions docs/source/reference/translation.rst
@@ -0,0 +1,167 @@
.. _translation:

###########
Translation
###########

********
The task
********

Translation is the task of converting text from a source language into a target language, such as English to Romanian.
This task is a subset of Sequence to Sequence tasks, which require the model to generate a variable-length sequence given an input sequence. In our case the English text is the input sequence, and the Romanian sentence is the output sequence generated by the model.

-----

*********
Inference
*********

The :class:`~flash.text.TranslationTask` is pre-trained on `WMT16 English/Romanian <https://www.statmt.org/wmt16/translation-task.html>`_, a dataset of English to Romanian samples, based on the Europarl corpora.

Use the :class:`~flash.text.TranslationTask` pretrained model for inference on any string sequence using :func:`~flash.text.TranslationTask.predict`:

.. code-block:: python

# import our libraries
from flash.text import TranslationTask


# 1. Load the model from a checkpoint
model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")

# 2. Perform inference from list of sequences
predictions = model.predict([
"BBC News went to meet one of the project's first graduates.",
"A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",
])
print(predictions)

Or on a given dataset:

.. code-block:: python

# import our libraries
from pytorch_lightning import Trainer
from flash import download_data
from flash.text import TranslationData, TranslationTask

# 1. Load the model from a checkpoint
model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")

# 2. Create dataset from file
datamodule = TranslationData.from_file(
predict_file="data/wmt_en_ro/predict.csv",
input="input",
)

# 3. Generate translations
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)

For more advanced inference options, see :ref:`predictions`.

-----

**********
Finetuning
**********

Say you want to finetune on your own translation data. We use the English/Romanian WMT16 dataset as an example, which contains a ``train.csv`` and a ``valid.csv`` structured like so:

.. code-block::

input,target
"Written statements and oral questions (tabling): see Minutes","Declaraţii scrise şi întrebări orale (depunere): consultaţi procesul-verbal"
"Closure of sitting","Ridicarea şedinţei"
...

In the above, the ``input`` column holds the English text and the ``target`` column holds its Romanian translation.
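
Before finetuning on your own files, it can help to check that they really have these two columns. The following is a minimal sketch using only the standard ``csv`` module; ``load_pairs`` and ``REQUIRED_COLUMNS`` are hypothetical names introduced for this illustration, not Flash APIs, and the sample row is taken from the excerpt above.

.. code-block:: python

    import csv
    import io

    # Column names taken from the CSV layout described above.
    REQUIRED_COLUMNS = ("input", "target")


    def load_pairs(handle):
        """Read (input, target) pairs, failing early if a column is missing."""
        reader = csv.DictReader(handle)
        header = reader.fieldnames or []
        missing = [name for name in REQUIRED_COLUMNS if name not in header]
        if missing:
            raise ValueError(f"CSV is missing required columns: {missing}")
        return [(row["input"], row["target"]) for row in reader]


    # An in-memory sample; for a real file use
    # open("data/wmt_en_ro/train.csv", newline="", encoding="utf-8").
    sample = io.StringIO(
        "input,target\n"
        '"Closure of sitting","Ridicarea şedinţei"\n'
    )
    pairs = load_pairs(sample)
    print(pairs)

Failing on a missing column up front gives a clearer error than one raised later inside the training loop.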

All we need is a few lines of code to train our model!

.. code-block:: python

# import our libraries
import flash
from flash import download_data
from flash.text import TranslationData, TranslationTask

# 1. Download data
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/')

# 2. Organize the data
datamodule = TranslationData.from_files(
train_file="data/wmt_en_ro/train.csv",
valid_file="data/wmt_en_ro/valid.csv",
test_file="data/wmt_en_ro/test.csv",
input="input",
target="target",
)

# 3. Build the task
model = TranslationTask()

# 4. Create trainer
trainer = flash.Trainer(max_epochs=5, gpus=1, precision=16)

# 5. Finetune the task
trainer.finetune(model, datamodule=datamodule)

# 6. Save trainer task
trainer.save_checkpoint("translation_model_en_ro.pt")

----

To run the example:

.. code-block:: bash

python flash_examples/finetuning/translation.py


------

*********************
Changing the backbone
*********************
By default, we use the `MarianNMT <https://marian-nmt.github.io/>`_ model for translation. You can change the model by passing the ``backbone`` parameter.

.. note:: When changing the backbone, make sure you pass in the same backbone to the Task and the Data object! Since this is a Seq2Seq task, make sure you use a Seq2Seq model.

.. code-block:: python

datamodule = TranslationData.from_files(
train_file="data/wmt_en_ro/train.csv",
valid_file="data/wmt_en_ro/valid.csv",
test_file="data/wmt_en_ro/test.csv",
input="input",
target="target",
backbone="t5-small",
)

model = TranslationTask(backbone="t5-small")

------

*************
API reference
*************

.. _translation_task:

TranslationTask
---------------

.. autoclass:: flash.text.seq2seq.translation.model.TranslationTask
:members:
:exclude-members: forward

.. _translation_data:

TranslationData
---------------

.. autoclass:: flash.text.seq2seq.translation.data.TranslationData

.. automethod:: flash.text.seq2seq.translation.data.TranslationData.from_files