
Merge branch 'master' into audio
brianjo authored Aug 8, 2019
2 parents 70095fd + 4d6dafc commit 3cb9b30
Showing 5 changed files with 82 additions and 86 deletions.
Binary file modified _static/img/chatbot/diff.png
100755 → 100644
Binary file modified _static/img/chatbot/pytorch_workflow.png
100755 → 100644
Binary file removed _static/img/hybrid.png
158 changes: 77 additions & 81 deletions beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py
@@ -1,24 +1,24 @@
# -*- coding: utf-8 -*-
"""
-Deploying a Seq2Seq Model with the Hybrid Frontend
+Deploying a Seq2Seq Model with TorchScript
==================================================
**Author:** `Matthew Inkawhich <https://github.com/MatthewInkawhich>`_
"""


######################################################################
# This tutorial will walk through the process of transitioning a
-# sequence-to-sequence model to Torch Script using PyTorch’s Hybrid
-# Frontend. The model that we will convert is the chatbot model from the
-# `Chatbot tutorial <https://pytorch.org/tutorials/beginner/chatbot_tutorial.html>`__.
+# sequence-to-sequence model to TorchScript using the TorchScript
+# API. The model that we will convert is the chatbot model from the
+# `Chatbot tutorial <https://pytorch.org/tutorials/beginner/chatbot_tutorial.html>`__.
# You can either treat this tutorial as a “Part 2” to the Chatbot tutorial
# and deploy your own pretrained model, or you can start with this
# document and use a pretrained model that we host. In the latter case,
# you can reference the original Chatbot tutorial for details
# regarding data preprocessing, model theory and definition, and model
# training.
#
-# What is the Hybrid Frontend?
+# What is TorchScript?
# ----------------------------
#
# During the research and development phase of a deep learning-based
@@ -34,13 +34,13 @@
# to target highly optimized hardware architectures. Also, a graph-based
# representation enables framework-agnostic model exportation. PyTorch
# provides mechanisms for incrementally converting eager-mode code into
-# Torch Script, a statically analyzable and optimizable subset of Python
+# TorchScript, a statically analyzable and optimizable subset of Python
# that Torch uses to represent deep learning programs independently from
# the Python runtime.
#
-# The API for converting eager-mode PyTorch programs into Torch Script is
+# The API for converting eager-mode PyTorch programs into TorchScript is
# found in the torch.jit module. This module has two core modalities for
-# converting an eager-mode model to a Torch Script graph representation:
+# converting an eager-mode model to a TorchScript graph representation:
# **tracing** and **scripting**. The ``torch.jit.trace`` function takes a
# module or function and a set of example inputs. It then runs the example
# input through the function or module while tracing the computational
@@ -52,19 +52,19 @@
# operations called along the execution route taken by the example input
# will be recorded. In other words, the control flow itself is not
# captured. To convert modules and functions containing data-dependent
-# control flow, a **scripting** mechanism is provided. Scripting
-# explicitly converts the module or function code to Torch Script,
-# including all possible control flow routes. To use script mode, be sure
-# to inherit from the the ``torch.jit.ScriptModule`` base class (instead
-# of ``torch.nn.Module``) and add a ``torch.jit.script`` decorator to your
-# Python function or a ``torch.jit.script_method`` decorator to your
-# module’s methods. The one caveat with using scripting is that it only
-# supports a restricted subset of Python. For all details relating to the
-# supported features, see the Torch Script `language
-# reference <https://pytorch.org/docs/master/jit.html>`__. To provide the
-# maximum flexibility, the modes of Torch Script can be composed to
-# represent your whole program, and these techniques can be applied
-# incrementally.
+# control flow, a **scripting** mechanism is provided. The
+# ``torch.jit.script`` function/decorator takes a module or function and
+# does not require example inputs. Scripting then explicitly converts
+# the module or function code to TorchScript, including all control flows.
+# One caveat with using scripting is that it only supports a subset of
+# Python, so you might need to rewrite the code to make it compatible
+# with the TorchScript syntax.
+#
+# For all details relating to the supported features, see the `TorchScript
+# language reference <https://pytorch.org/docs/master/jit.html>`__.
+# To provide the maximum flexibility, you can also mix tracing and scripting
+# modes together to represent your whole program, and these techniques can
+# be applied incrementally.
#
# .. figure:: /_static/img/chatbot/pytorch_workflow.png
# :align: center
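
For reference, a minimal sketch of the two conversion modes described above; the toy names ``MyCell`` and ``loop_sum`` are illustrative assumptions, not part of the tutorial file:

    import torch
    import torch.nn as nn

    class MyCell(nn.Module):
        def forward(self, x):
            # No data-dependent control flow, so tracing is sufficient
            return torch.tanh(x)

    def loop_sum(x):
        # type: (Tensor) -> Tensor
        # Data-dependent loop: scripting captures the control flow itself
        total = torch.zeros(1)
        for i in range(int(x.sum().item())):
            total = total + 1
        return total

    traced = torch.jit.trace(MyCell(), torch.rand(3))  # records one execution path
    scripted = torch.jit.script(loop_sum)              # compiles all control-flow paths
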
@@ -273,7 +273,7 @@ def indexesFromSentence(voc, sentence):
# used by the ``torch.nn.utils.rnn.pack_padded_sequence`` function when
# padding.
#
-# Hybrid Frontend Notes:
+# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Since the encoder’s ``forward`` function does not contain any
@@ -296,6 +296,7 @@ def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

def forward(self, input_seq, input_lengths, hidden=None):
+# type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
# Convert word indexes to embeddings
embedded = self.embedding(input_seq)
# Pack padded batch of sequences for RNN module
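
The ``# type:`` comment added above is how TorchScript learns non-Tensor and ``Optional`` argument types (PEP 3107 style). A minimal sketch on a standalone function; ``scale_first`` is an assumed name for illustration:

    import torch
    from typing import Optional

    @torch.jit.script
    def scale_first(x, factor, bias=None):
        # type: (Tensor, float, Optional[Tensor]) -> Tensor
        out = x * factor
        if bias is not None:  # the compiler understands Optional refinement
            out = out + bias
        return out

    print(scale_first(torch.ones(2), 3.0))  # tensor([3., 3.])
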
@@ -325,18 +326,18 @@ def forward(self, input_seq, input_lengths, hidden=None):
#

# Luong attention layer
-class Attn(torch.nn.Module):
+class Attn(nn.Module):
def __init__(self, method, hidden_size):
super(Attn, self).__init__()
self.method = method
if self.method not in ['dot', 'general', 'concat']:
raise ValueError(self.method, "is not an appropriate attention method.")
self.hidden_size = hidden_size
if self.method == 'general':
-self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
+self.attn = nn.Linear(self.hidden_size, hidden_size)
elif self.method == 'concat':
-self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
-self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))
+self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
+self.v = nn.Parameter(torch.FloatTensor(hidden_size))

def dot_score(self, hidden, encoder_output):
return torch.sum(hidden * encoder_output, dim=2)
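
For intuition, ``dot_score`` multiplies the decoder hidden state elementwise against every encoder output and sums over the hidden dimension. A minimal shape sketch; the sizes are assumptions:

    import torch

    hidden = torch.randn(1, 64, 500)            # (1, batch, hidden_size)
    encoder_outputs = torch.randn(10, 64, 500)  # (max_length, batch, hidden_size)
    # Broadcasting expands ``hidden`` across max_length before the reduction
    scores = torch.sum(hidden * encoder_outputs, dim=2)
    print(scores.shape)  # torch.Size([10, 64])
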
@@ -383,14 +384,14 @@ def forward(self, hidden, encoder_outputs):
# weighted sum indicating what parts of the encoder’s output to pay
# attention to. From here, we use a linear layer and softmax normalization
# to select the next word in the output sequence.
-#
-# Hybrid Frontend Notes:
+
+# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Similarly to the ``EncoderRNN``, this module does not contain any
# data-dependent control flow. Therefore, we can once again use
-# **tracing** to convert this model to Torch Script after it is
-# initialized and its parameters are loaded.
+# **tracing** to convert this model to TorchScript after it
+# is initialized and its parameters are loaded.
#

class LuongAttnDecoderRNN(nn.Module):
@@ -465,18 +466,18 @@ def forward(self, input_step, last_hidden, encoder_outputs):
# terminates either if the ``decoded_words`` list has reached a length of
# *MAX_LENGTH* or if the predicted word is the *EOS_token*.
#
-# Hybrid Frontend Notes:
+# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# The ``forward`` method of this module involves iterating over the range
# of :math:`[0, max\_length)` when decoding an output sequence one word at
# a time. Because of this, we should use **scripting** to convert this
-# module to Torch Script. Unlike with our encoder and decoder models,
+# module to TorchScript. Unlike with our encoder and decoder models,
# which we can trace, we must make some necessary changes to the
# ``GreedySearchDecoder`` module in order to initialize an object without
# error. In other words, we must ensure that our module adheres to the
-# rules of the scripting mechanism, and does not utilize any language
-# features outside of the subset of Python that Torch Script includes.
+# rules of the TorchScript mechanism, and does not utilize any language
+# features outside of the subset of Python that TorchScript includes.
#
# To get an idea of some manipulations that may be required, we will go
# over the diffs between the ``GreedySearchDecoder`` implementation from
@@ -491,12 +492,6 @@ def forward(self, input_step, last_hidden, encoder_outputs):
# Changes:
# ^^^^^^^^
#
-# - ``nn.Module`` -> ``torch.jit.ScriptModule``
-#
-# - In order to use PyTorch’s scripting mechanism on a module, that
-# module must inherit from the ``torch.jit.ScriptModule``.
-#
-#
# - Added ``decoder_n_layers`` to the constructor arguments
#
# - This change stems from the fact that the encoder and decoder
@@ -523,16 +518,9 @@ def forward(self, input_step, last_hidden, encoder_outputs):
# ``self._SOS_token``.
#
#
-# - Add the ``torch.jit.script_method`` decorator to the ``forward``
-# method
-#
-# - Adding this decorator lets the JIT compiler know that the function
-# that it is decorating should be scripted.
-#
-#
# - Enforce types of ``forward`` method arguments
#
-# - By default, all parameters to a Torch Script function are assumed
+# - By default, all parameters to a TorchScript function are assumed
# to be Tensor. If we need to pass an argument of a different type,
# we can use function type annotations as introduced in `PEP
# 3107 <https://www.python.org/dev/peps/pep-3107/>`__. In addition,
@@ -553,7 +541,7 @@ def forward(self, input_step, last_hidden, encoder_outputs):
# ``self._SOS_token``.
#

-class GreedySearchDecoder(torch.jit.ScriptModule):
+class GreedySearchDecoder(nn.Module):
def __init__(self, encoder, decoder, decoder_n_layers):
super(GreedySearchDecoder, self).__init__()
self.encoder = encoder
@@ -564,7 +552,6 @@ def __init__(self, encoder, decoder, decoder_n_layers):

__constants__ = ['_device', '_SOS_token', '_decoder_n_layers']

-@torch.jit.script_method
def forward(self, input_seq : torch.Tensor, input_length : torch.Tensor, max_length : int):
# Forward input through encoder model
encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
@@ -613,7 +600,7 @@ def forward(self, input_seq : torch.Tensor, input_length : torch.Tensor, max_len
# an argument, normalizes it, evaluates it, and prints the response.
#

-def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
+def evaluate(searcher, voc, sentence, max_length=MAX_LENGTH):
### Format input sentence as a batch
# words -> indexes
indexes_batch = [indexesFromSentence(voc, sentence)]
@@ -632,7 +619,7 @@ def evaluate(searcher, voc, sentence, max_length=MAX_LENGTH):


# Evaluate inputs from user input (stdin)
-def evaluateInput(encoder, decoder, searcher, voc):
+def evaluateInput(searcher, voc):
input_sentence = ''
while(1):
try:
@@ -643,7 +630,7 @@ def evaluateInput(searcher, voc):
# Normalize sentence
input_sentence = normalizeString(input_sentence)
# Evaluate sentence
-output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
+output_words = evaluate(searcher, voc, input_sentence)
# Format and print response sentence
output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
print('Bot:', ' '.join(output_words))
@@ -652,12 +639,12 @@ def evaluateInput(searcher, voc):
print("Error: Encountered unknown word.")

# Normalize input sentence and call evaluate()
-def evaluateExample(sentence, encoder, decoder, searcher, voc):
+def evaluateExample(sentence, searcher, voc):
print("> " + sentence)
# Normalize sentence
input_sentence = normalizeString(sentence)
# Evaluate sentence
-output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
+output_words = evaluate(searcher, voc, input_sentence)
output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
print('Bot:', ' '.join(output_words))

@@ -700,14 +687,17 @@ def evaluateExample(sentence, searcher, voc):
# ``checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))``
# line.
#
-# Hybrid Frontend Notes:
+# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Notice that we initialize and load parameters into our encoder and
-# decoder models as usual. Also, we must call ``.to(device)`` to set the
-# device options of the models and ``.eval()`` to set the dropout layers
-# to test mode **before** we trace the models. ``TracedModule`` objects do
-# not inherit the ``to`` or ``eval`` methods.
+# decoder models as usual. If you are using tracing mode (``torch.jit.trace``)
+# for some part of your models, you must call ``.to(device)`` to set the device
+# options of the models and ``.eval()`` to set the dropout layers to test mode
+# **before** tracing the models. ``TracedModule`` objects do not inherit the
+# ``to`` or ``eval`` methods. Since in this tutorial we are only using
+# scripting instead of tracing, we only need to do this before we do
+# evaluation (which is the same as we normally do in eager mode).
#

save_dir = os.path.join("data", "save")
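
A minimal sketch of the ordering caveat above; ``TinyDropoutNet`` and the input shape are assumptions for illustration:

    import torch
    import torch.nn as nn

    class TinyDropoutNet(nn.Module):
        def __init__(self):
            super(TinyDropoutNet, self).__init__()
            self.drop = nn.Dropout(0.5)

        def forward(self, x):
            return self.drop(x)

    device = torch.device("cpu")  # assumption: CPU for this sketch
    tiny = TinyDropoutNet().to(device)
    tiny.eval()  # dropout to test mode *before* tracing; TracedModule lacks ``eval``
    traced_tiny = torch.jit.trace(tiny, torch.rand(2, 4, device=device))
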
@@ -766,16 +756,14 @@ def evaluateExample(sentence, searcher, voc):


######################################################################
-# Convert Model to Torch Script
+# Convert Model to TorchScript
# -----------------------------
#
# Encoder
# ~~~~~~~
#
-# As previously mentioned, to convert the encoder model to Torch Script,
-# we use **tracing**. Tracing any module requires running an example input
-# through the model’s ``forward`` method and trace the computational graph
-# that the data encounters. The encoder model takes an input sequence and
+# As previously mentioned, to convert the encoder model to TorchScript,
+# we use **scripting**. The encoder model takes an input sequence and
# a corresponding lengths tensor. Therefore, we create an example input
# sequence tensor ``test_seq``, which is of appropriate size (MAX_LENGTH,
# 1), contains numbers in the appropriate range
@@ -803,13 +791,13 @@ def evaluateExample(sentence, searcher, voc):
# ~~~~~~~~~~~~~~~~~~~
#
# Recall that we scripted our searcher module due to the presence of
-# data-dependent control flow. In the case of scripting, we do the
-# conversion work up front by adding the decorator and making sure the
-# implementation complies with scripting rules. We initialize the scripted
-# searcher the same way that we would initialize an un-scripted variant.
+# data-dependent control flow. In the case of scripting, we do the necessary
+# language changes to make sure the implementation complies with
+# TorchScript. We initialize the scripted searcher the same way that we
+# would initialize an un-scripted variant.
#

-### Convert encoder model
+### Compile the whole greedy search model into a TorchScript model
# Create artificial inputs
test_seq = torch.LongTensor(MAX_LENGTH, 1).random_(0, voc.num_words).to(device)
test_seq_length = torch.LongTensor([test_seq.size()[0]]).to(device)
@@ -824,19 +812,21 @@ def evaluateExample(sentence, searcher, voc):
# Trace the model
traced_decoder = torch.jit.trace(decoder, (test_decoder_input, test_decoder_hidden, test_encoder_outputs))

-### Initialize searcher module
-scripted_searcher = GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers)
+### Initialize the searcher module by wrapping it in a ``torch.jit.script`` call
+scripted_searcher = torch.jit.script(GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers))
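
As a quick sanity check (a sketch; the max-length value of 10 is arbitrary), the scripted searcher is callable like a regular module, reusing the artificial inputs created above:

    tokens, scores = scripted_searcher(test_seq, test_seq_length, 10)
    print(tokens.shape)  # expected: up to 10 decoded token ids
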




######################################################################
# Print Graphs
# ------------
#
-# Now that our models are in Torch Script form, we can print the graphs of
+# Now that our models are in TorchScript form, we can print the graphs of
# each to ensure that we captured the computational graph appropriately.
-# Since our ``scripted_searcher`` contains our ``traced_encoder`` and
-# ``traced_decoder``, these graphs will print inline.
#
+# Since TorchScript allows us to recursively compile the whole model
+# hierarchy and inline the ``encoder`` and ``decoder`` graphs into a single
+# graph, we just need to print the ``scripted_searcher`` graph.

print('scripted_searcher graph:\n', scripted_searcher.graph)
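
Alongside ``.graph``, a scripted module also exposes a ``.code`` attribute, a more readable Python-like rendering of the same program; printing it is an optional extra check:

    print('scripted_searcher code:\n', scripted_searcher.code)
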

@@ -845,19 +835,25 @@ def evaluateExample(sentence, searcher, voc):
# Run Evaluation
# --------------
#
-# Finally, we will run evaluation of the chatbot model using the Torch
-# Script models. If converted correctly, the models will behave exactly as
-# they would in their eager-mode representation.
+# Finally, we will run evaluation of the chatbot model using the TorchScript
+# models. If converted correctly, the models will behave exactly as they
+# would in their eager-mode representation.
#
# By default, we evaluate a few common query sentences. If you want to
# chat with the bot yourself, uncomment the ``evaluateInput`` line and
# give it a spin.
#


+# Use appropriate device
+scripted_searcher.to(device)
+# Set dropout layers to eval mode
+scripted_searcher.eval()

# Evaluate examples
sentences = ["hello", "what's up?", "who are you?", "where am I?", "where are you from?"]
for s in sentences:
-evaluateExample(s, traced_encoder, traced_decoder, scripted_searcher, voc)
+evaluateExample(s, scripted_searcher, voc)

# Evaluate your input
#evaluateInput(traced_encoder, traced_decoder, scripted_searcher, voc)
@@ -867,7 +863,7 @@ def evaluateExample(sentence, searcher, voc):
# Save Model
# ----------
#
-# Now that we have successfully converted our model to Torch Script, we
+# Now that we have successfully converted our model to TorchScript, we
# will serialize it for use in a non-Python deployment environment. To do
# this, we can simply save our ``scripted_searcher`` module, as this is
# the user-facing interface for running inference against the chatbot
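
A minimal sketch of the serialization step described above; the file name ``scripted_chatbot.pth`` is an assumption:

    scripted_searcher.save("scripted_chatbot.pth")            # serialize the TorchScript module
    loaded_searcher = torch.jit.load("scripted_chatbot.pth")  # reload with no Python class needed
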
