Add huggingface inferentia serving example #184

Merged
merged 1 commit on Nov 4, 2021
4 changes: 4 additions & 0 deletions huggingface/.gitignore
@@ -0,0 +1,4 @@
special_tokens_map.json
tokenizer.json
tokenizer_config.json
vocab.txt
1 change: 1 addition & 0 deletions huggingface/inferentia/.gitignore
@@ -0,0 +1 @@
qa_payload.json
173 changes: 173 additions & 0 deletions huggingface/inferentia/README.md
@@ -0,0 +1,173 @@
## Serving a Hugging Face model with AWS Inferentia

[AWS Inferentia](https://aws.amazon.com/machine-learning/inferentia/) is a high-performance machine
learning inference chip custom designed by AWS. Amazon EC2 Inf1 instances are powered by AWS
Inferentia chips, which provide the lowest cost per inference in the cloud and lower the barrier
for everyday developers to use machine learning (ML) at scale.

In this demo, you will learn how to deploy a PyTorch model with **djl-serving** on an Amazon EC2 Inf1 instance.

## Setup environment

### Launch Inf1 EC2 instance

Launch an Inf1 instance by following the [Install Neuron Instructions](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-intro/pytorch-setup/pytorch-install.html#install-neuron-pytorch).

This demo was tested with Neuron SDK 1.16.0 and PyTorch 1.9.1 on the Ubuntu DLAMI.
Make sure you have Neuron Runtime 2.x installed:

```
sudo apt-get update -y

# Stop any existing Neuron Runtime 1.0 daemon (neuron-rtd), if present
sudo systemctl stop neuron-rtd

sudo apt-get install linux-headers-$(uname -r) -y
sudo apt-get install aws-neuron-dkms -y
sudo apt-get install aws-neuron-tools -y
```
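
You can quickly verify that the Neuron driver is loaded from Python (a sketch; it assumes the `aws-neuron-dkms` package above installed successfully and exposes the usual `/dev/neuron*` device files):

```
# Quick check that the Neuron driver is loaded: the driver exposes one
# /dev/neuron* device per Inferentia chip on the instance.
import glob

devices = glob.glob("/dev/neuron*")
print("Neuron devices:", devices or "none found - check the driver installation")
```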

### Install the AWS Inferentia Neuron SDK

The Inferentia Neuron SDK is required to convert a pre-trained PyTorch model into a Neuron traced model.

```
python3 -m venv myenv

source myenv/bin/activate
pip install -U pip
pip install torchvision torch-neuron==1.9.1.2.0.318.0 'neuron-cc[tensorflow]==1.7.3.0' --extra-index-url=https://pip.repos.neuron.amazonaws.com
```

After installing the Inferentia Neuron SDK, you will find `libtorchneuron.so` installed in the
`myenv/lib/python3.6/site-packages/torch_neuron/lib` folder.
You need to set the following environment variable to enable Inferentia for DJL:

```
export PYTORCH_EXTRA_LIBRARY_PATH=$(python -m site | grep $VIRTUAL_ENV | awk -F"'" '{print $2}')/torch_neuron/lib/libtorchneuron.so
```

`libtorchneuron.so` depends on other shared libraries in the same folder, so you also need to set `LD_LIBRARY_PATH` to make it work:

```
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$(python -m site | grep $VIRTUAL_ENV | awk -F"'" '{print $2}')/torch_neuron/lib/
```
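
To confirm the paths above resolve correctly, you can run a quick check (a sketch; it only assumes the virtual environment created above is active):

```
# Verify that torch_neuron imports and that libtorchneuron.so is where
# PYTORCH_EXTRA_LIBRARY_PATH expects it.
import os

import torch_neuron

lib_dir = os.path.join(os.path.dirname(torch_neuron.__file__), "lib")
print("libtorchneuron.so found:",
      os.path.isfile(os.path.join(lib_dir, "libtorchneuron.so")))
```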

## Compile your model into Neuron traced model

Use the following command to trace a Hugging Face question answering model. The script can be found in this repo at [trace.py](trace.py).
You can also download a pre-traced model from: https://resources.djl.ai/test-models/pytorch/bert_qa-inf1.tar.gz

You can find more details in the [Neuron tutorials](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/neuron-frameworks/pytorch-neuron/tutorials/index.html).

```
cd huggingface/inferentia
python trace.py
```

After executing the above command, you have a Neuron traced model `question_answering.pt` ready for inference in the `question_answering` folder.
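
As an optional sanity check, you can load the traced model and run it on a dummy fixed-shape input before serving it (a sketch; it assumes the trace completed successfully and you are running on the Inf1 instance):

```
# Load the Neuron traced model and run a dummy [1, 128] input, matching the
# fixed shape it was compiled with.
import torch
import torch_neuron  # noqa: F401 - registers the Neuron ops with TorchScript

model = torch.jit.load("question_answering/question_answering.pt")
dummy_ids = torch.zeros(1, 128, dtype=torch.long)
dummy_mask = torch.zeros(1, 128, dtype=torch.long)
out = model(dummy_ids, dummy_mask)
print(out[0].shape, out[1].shape)  # expect torch.Size([1, 128]) for both
```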

## Install djl-serving

This demo requires **djl-serving** 0.14.0 (not released yet), so you need to run **djl-serving** from source.

```
git clone https://github.com/deepjavalibrary/djl-serving.git
```

## Deploy question answering model with DJL

**djl-serving** allows you to run the Hugging Face model with either the Java or the Python engine.

### Run Java engine

The Neuron SDK requires the precxx11 version of the PyTorch native library, so you need to set the
following environment variable to instruct DJL to load the precxx11 PyTorch native library:

```
export PYTORCH_PRECXX11=true
```

You can use **djl-serving** to deploy your question answering model out of the box. It serves
your model in a tensor-in, tensor-out fashion, and you can leverage DJL's built-in pre-processing and
post-processing to perform BERT tokenization for you. Simply provide a `serving.properties` file
in the model directory:

```
translatorFactory=ai.djl.pytorch.zoo.nlp.qa.PtBertQATranslatorFactory
tokenizer=distilbert
padding=true
max_length=128
```

**Note:** `padding=true` is required when using the Neuron SDK, since the Neuron traced model uses a fixed input shape.

```
cd djl-serving/serving

./gradlew :serving:run --args="-m bert_qa::PyTorch:*=file:$HOME/source/djl-demo/huggingface/inferentia/question_answering"
```
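
The fixed-shape requirement can be seen by running the same tokenizer settings that `trace.py` and `serving.properties` use (a sketch; it assumes the `transformers` package from `requirements.txt` is installed):

```
# Every request is padded (or truncated) to exactly 128 tokens, matching the
# shape the Neuron model was traced with.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")
tokens = tokenizer.encode_plus("How is the weather",
                               "The weather is nice, it is beautiful day",
                               max_length=128,
                               truncation=True,
                               padding='max_length',
                               add_special_tokens=True,
                               return_tensors="pt")
print(tokens["input_ids"].shape)  # torch.Size([1, 128])
```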

### Run Python engine

**djl-serving**'s Python engine is compatible with [TorchServe](https://github.com/pytorch/serve) `.mar` files.
You can deploy a TorchServe `.mar` file directly in **djl-serving**.

The **djl-serving** Python model script format is similar to TorchServe's, but simpler.
See the [DJL Python engine](https://github.com/deepjavalibrary/djl-serving/tree/master/engines/python) documentation for how to
write a DJL Python model; a minimal sketch follows.
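
The sketch below shows the minimal shape of such a script; the full question answering implementation for this demo is in `question_answering/model.py`:

```
# Minimal DJL Python engine script: DJL calls handle() for every request,
# and an empty Input marks the initialization request.
from djl_python import Input, Output


def handle(inputs: Input):
    if inputs.is_empty():
        # initialization request, nothing to return
        return None
    data = inputs.get_as_json()
    outputs = Output()
    outputs.add_as_json(data)  # echo the parsed payload back
    return outputs
```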

```
cd djl-serving/serving

./gradlew :serving:run --args="-m bert_qa::Python:*=file:$HOME/source/djl-demo/huggingface/inferentia/question_answering"
```

## Run inference

**djl-serving** provides a [REST API](https://github.com/deepjavalibrary/djl-serving/blob/master/serving/docs/inference_api.md) that allows users to run inference.
The API is compatible with [TorchServe](https://github.com/pytorch/serve) and [MMS](https://github.com/awslabs/multi-model-server).

```
curl -X POST http://127.0.0.1:8080/predictions/bert_qa \
-H "Content-Type: application/json" \
-d '{"question": "How is the weather", "paragraph": "The weather is nice, it is beautiful day"}'
```
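
You can send the same request from Python (a sketch; it assumes the `requests` package is installed and **djl-serving** is listening on localhost:8080):

```
# Call the inference endpoint programmatically.
import requests

payload = {
    "question": "How is the weather",
    "paragraph": "The weather is nice, it is beautiful day",
}
resp = requests.post("http://127.0.0.1:8080/predictions/bert_qa", json=payload)
print(resp.status_code, resp.json())
```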

## Benchmark

We use Apache Bench (`ab`) to run benchmark testing:

```
sudo apt-get install -y apache2-utils

echo '{"question": "How is the weather", "paragraph": "The weather is nice, it is beautiful day"}' > qa_payload.json
ab -c 8 -n 8000 -k -p qa_payload.json \
-T "application/json" \
"http://127.0.0.1:8080/predictions/bert_qa"
```

## Performance

By default, **djl-serving** only loads the model on a single NeuronCore (assuming the model was compiled for one NeuronCore).
To run inference on multiple NeuronCores, use the following commands to start **djl-serving**:

```
# java engine
./gradlew :serving:run --args="-m bert_qa::PyTorch:nc0;nc1;nc2;nc3=file:$HOME/source/djl-demo/huggingface/inferentia/question_answering"

# python engine
./gradlew :serving:run --args="-m bert_qa::Python:nc0;nc1;nc2;nc3=file:$HOME/source/djl-demo/huggingface/inferentia/question_answering"
```

If your model was traced for 2 NeuronCores, you can use the following commands:

```
# java engine
./gradlew :serving:run --args="-m bert_qa::PyTorch:nc0-1;nc2-3=file:$HOME/source/djl-demo/huggingface/inferentia/question_answering"

# python engine
./gradlew :serving:run --args="-m bert_qa::Python:nc0-1;nc2-3=file:$HOME/source/djl-demo/huggingface/inferentia/question_answering"
```


100 changes: 100 additions & 0 deletions huggingface/inferentia/question_answering/model.py
@@ -0,0 +1,100 @@
#!/usr/bin/env python
#
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import logging
import os

import torch
import torch_neuron
from djl_python import Input
from djl_python import Output
from transformers import AutoTokenizer


class QuestionAnswering(object):

def __init__(self):
self.max_length = 128
self.device = None
self.model = None
self.tokenizer = None
self.initialized = False

def initialize(self, properties: dict):
visible_cores = os.getenv("NEURON_RT_VISIBLE_CORES")
logging.info("NEURON_RT_VISIBLE_CORES: " + visible_cores)

device_id = properties.get("device_id")
device_id = "cpu" if device_id == "-1" else "cuda:" + device_id
self.device = torch.device(device_id)
self.model = torch.jit.load("question_answering.pt").to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(os.getcwd(), do_lower_case=True)
self.initialized = True

def inference(self, inputs: Input):
try:
data = inputs.get_as_json()
question = data["question"]
paragraph = data["paragraph"]
tokens = self.tokenizer.encode_plus(question,
paragraph,
max_length=self.max_length,
truncation=True,
padding='max_length',
add_special_tokens=True,
return_tensors="pt")
input_ids = tokens["input_ids"].to(self.device)
attention_mask = tokens["attention_mask"].to(self.device)

inferences = []
out = self.model(input_ids, attention_mask)
answer_start_scores = out[0]
answer_end_scores = out[1]

num_rows, num_cols = answer_start_scores.shape
for i in range(num_rows):
answer_start_scores_one_seq = answer_start_scores[i].unsqueeze(0)
answer_start = torch.argmax(answer_start_scores_one_seq)
answer_end_scores_one_seq = answer_end_scores[i].unsqueeze(0)
answer_end = torch.argmax(answer_end_scores_one_seq) + 1
token_id = self.tokenizer.convert_ids_to_tokens(input_ids[i].tolist()[answer_start:answer_end])
prediction = self.tokenizer.convert_tokens_to_string(token_id)
inferences.append(prediction)

outputs = Output()
outputs.add_as_json(inferences)
except Exception as e:
logging.error(e, exc_info=True)
# error handling
outputs = Output(code=500, message=str(e))
outputs.add("inference failed", key="data")

return outputs


_model = QuestionAnswering()


def handle(inputs: Input):
"""
Default handler function
"""
if not _model.initialized:
_model.initialize(inputs.get_properties())

if inputs.is_empty():
# initialization request
return None

return _model.inference(inputs)

4 changes: 4 additions & 0 deletions huggingface/inferentia/question_answering/serving.properties
@@ -0,0 +1,4 @@
translatorFactory=ai.djl.pytorch.zoo.nlp.qa.PtBertQATranslatorFactory
tokenizer=distilbert
padding=true
max_length=128
4 changes: 4 additions & 0 deletions huggingface/inferentia/requirements.txt
@@ -0,0 +1,4 @@
--extra-index-url=https://pip.repos.neuron.amazonaws.com
transformers
torch-neuron==1.9.1.2.0.318.0
neuron-cc[tensorflow]==1.7.3.0
81 changes: 81 additions & 0 deletions huggingface/inferentia/trace.py
@@ -0,0 +1,81 @@
import logging
import os
import sys

import torch
import torch_neuron
import transformers
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModelForQuestionAnswering,
AutoModelForTokenClassification
)

# Enable logging so we can see any important warnings
logger = logging.getLogger('Neuron')
logger.setLevel(logging.INFO)


# compile model to use one neuron core only
# num_cores = 4  # 4 NeuronCores on inf1.xlarge and inf1.2xlarge; use 16 on inf1.6xlarge
# nc_env = ','.join(['1'] * num_cores)
# os.environ['NEURONCORE_GROUP_SIZES'] = nc_env


def transformers_model_downloader(app):
model_file = os.path.join(app, app + ".pt")
if os.path.isfile(model_file):
print("model already downloaded: " + model_file)
return

print("Download model for: ", model_file)
if app == "text_classification":
model_name = "bert-base-uncased"
max_length = 150
model = AutoModelForSequenceClassification.from_pretrained(model_name, torchscript=True, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=True)
elif app == "question_answering":
model_name = "distilbert-base-uncased-distilled-squad"
max_length = 128
model = AutoModelForQuestionAnswering.from_pretrained(model_name, torchscript=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=True)
elif app == "token_classification":
model_name = "bert-base-uncased"
max_length = 150
model = AutoModelForTokenClassification.from_pretrained(model_name, torchscript=True, num_labels=9)
tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=True)
else:
print("Unknown application: " + app)
return

text = "How is the weather"
paraphrase = tokenizer.encode_plus(text,
max_length=max_length,
truncation=True,
padding='max_length',
add_special_tokens=True,
return_tensors='pt')
example_inputs = paraphrase['input_ids'], paraphrase['attention_mask']

traced_model = torch.neuron.trace(model, example_inputs)

# Export to saved model
os.makedirs(app, exist_ok=True)
traced_model.save(model_file)

tokenizer.save_pretrained(app)

logging.info("Compile model %s success.", app)


if __name__ == "__main__":
logging.basicConfig(stream=sys.stdout,
format="%(message)s",
level=logging.INFO)
logging.info("Transformers version: %s", transformers.__version__)

transformers_model_downloader("question_answering")
# transformers_model_downloader("text_classification")
# transformers_model_downloader("token_classification")