Chronos: train TCN model on GPU and speed up inference on CPU #5594

Merged: 6 commits, Sep 9, 2022
7 changes: 4 additions & 3 deletions docs/readthedocs/source/_static/js/chronos_tutorial.js
@@ -51,7 +51,8 @@ $(".checkboxes").click(function(){
var ids = ["ChronosForecaster","TuneaForecasting","AutoTSEstimator","AutoWIDE",
"MultvarWIDE","MultstepWIDE","LSTMForecaster","AutoProphet","AnomalyDetection",
"DeepARmodel","TFTmodel","hyperparameter","taxiDataset","distributedFashion",
"ONNX","Quantize","TCMFForecaster","PenalizeUnderestimation"];
"ONNX","Quantize","TCMFForecaster","PenalizeUnderestimation",
"GPUtrainingCPUacceleration"];
showTutorials(ids);
var disIds = ["simulation"];
disCheck(disIds);
@@ -94,7 +95,7 @@ $(".checkboxes").click(function(){
         disCheck(disIds);
     }
     else if(vals.includes("customized_model")){
-        var ids = ["AutoTSEstimator","DeepARmodel","TFTmodel"];
+        var ids = ["AutoTSEstimator","DeepARmodel","TFTmodel", "GPUtrainingCPUacceleration"];
         showTutorials(ids);
         var disIds = ["anomaly_detection","simulation","onnxruntime","quantization","distributed"];
         disCheck(disIds);
@@ -114,7 +115,7 @@ $(".checkboxes").click(function(){
         disCheck(disIds);
     }
     else if(vals.includes("forecast") && vals.includes("customized_model")){
-        var ids = ["DeepARmodel","TFTmodel","AutoTSEstimator"];
+        var ids = ["DeepARmodel","TFTmodel","AutoTSEstimator","GPUtrainingCPUacceleration"];
         showTutorials(ids);
         var disIds = ["anomaly_detection","simulation","onnxruntime","quantization","distributed"];
         disCheck(disIds);
10 changes: 10 additions & 0 deletions docs/readthedocs/source/doc/Chronos/QuickStart/index.md
@@ -244,6 +244,16 @@
 </details>
 <hr>
+<details id="GPUtrainingCPUacceleration">
+<summary>
+<a href="https://github.com/intel-analytics/BigDL/tree/main/python/chronos/example/inference-acceleration">Accelerate the inference speed of a model trained on another platform</a>
+<p>Tag: <button value="forecast">forecast</button>&nbsp;<button value="customized_model">customized model</button></p>
+</summary>
+<img src="../../../_images/GitHub-Mark-32px.png"><a href="https://github.com/intel-analytics/BigDL/tree/main/python/chronos/example/inference-acceleration">View source on GitHub</a>
+<p>In this example, we train a model on GPU and accelerate its inference with onnxruntime on CPU.</p>
+</details>
+<hr>
 </div>
 <script src="../../../_static/js/chronos_tutorial.js"></script>
90 changes: 90 additions & 0 deletions python/chronos/example/inference-acceleration/cpu_inference_acceleration.py
@@ -0,0 +1,90 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from bigdl.chronos.pytorch import TSTrainer as Trainer
from bigdl.chronos.model.tcn import model_creator
from bigdl.chronos.metric.forecast_metrics import Evaluator
from bigdl.chronos.data.repo_dataset import get_public_dataset
from sklearn.preprocessing import StandardScaler

def gen_dataloader():
    # load the built-in nyc_taxi dataset with a train/val/test split
    tsdata_train, tsdata_val,\
        tsdata_test = get_public_dataset(name='nyc_taxi',
                                         with_split=True,
                                         val_ratio=0.1,
                                         test_ratio=0.1)

    # preprocess: deduplicate, impute, add datetime features,
    # standardize (fit the scaler on the training set only) and roll
    stand = StandardScaler()
    for tsdata in [tsdata_train, tsdata_val, tsdata_test]:
        tsdata.deduplicate()\
              .impute()\
              .gen_dt_feature()\
              .scale(stand, fit=tsdata is tsdata_train)\
              .roll(lookback=48, horizon=1)

    tsdata_traindataloader = tsdata_train.to_torch_data_loader(batch_size=32)
    tsdata_valdataloader = tsdata_val.to_torch_data_loader(batch_size=32, shuffle=False)
    tsdata_testdataloader = tsdata_test.to_torch_data_loader(batch_size=32, shuffle=False)

    return tsdata_traindataloader, tsdata_valdataloader, tsdata_testdataloader


def predict_wrapper(model, input_sample):
    # thin wrapper so Evaluator.get_latency can time a single forward pass
    model(input_sample)


if __name__ == '__main__':

    # create data loaders for train/valid/test
    tsdata_traindataloader,\
    tsdata_valdataloader,\
    tsdata_testdataloader = gen_dataloader()

    # create a model
    # This could be an arbitrary model; we choose the built-in TCN model here
    config = {'input_feature_num': 8,
              'output_feature_num': 1,
              'past_seq_len': 48,
              'future_seq_len': 1,
              'kernel_size': 3,
              'repo_initialization': True,
              'dropout': 0.1,
              'seed': 0,
              'num_channels': [30] * 7
              }
    model = model_creator(config)
    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(lr=0.001, params=model.parameters())
    lit_model = Trainer.compile(model, loss, optimizer)

    # train the model
    # You may use any method to train the model, on either GPU or CPU
    trainer = Trainer(max_epochs=3,
                      accelerator='gpu',
                      devices=1,
                      )
    trainer.fit(lit_model, tsdata_traindataloader, tsdata_valdataloader)

    # get an input sample from the train data loader
    x = None
    for x, _ in tsdata_traindataloader:
        break
    input_sample = x[0].unsqueeze(0)

    # speed up the model for CPU inference using Chronos TSTrainer
    speed_model = Trainer.trace(lit_model, accelerator="onnxruntime", input_sample=input_sample)

    # compare the latency of the original and the accelerated model
    print("original pytorch latency (ms):", Evaluator.get_latency(predict_wrapper, lit_model, input_sample))
    print("onnxruntime latency (ms):", Evaluator.get_latency(predict_wrapper, speed_model, input_sample))
20 changes: 20 additions & 0 deletions python/chronos/example/inference-acceleration/readme.md
@@ -0,0 +1,20 @@
# Accelerate the inference speed of a model trained on another platform

## Introduction
Chronos provides many built-in models, wrapped in forecasters, detectors and simulators, that are optimized for CPU (especially Intel CPU) platforms.

Users may, however, want to run their own model, or a built-in model trained on another platform (e.g. GPU), while carrying out inference on CPU. Chronos can also help users accelerate such models for inference.

In this example, we train a TCN model on GPU and accelerate its inference with onnxruntime on CPU.
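
The core of the acceleration is a single `Trainer.trace` call. Below is a minimal sketch, assuming a `lit_model` already compiled with `Trainer.compile` and an `input_sample` drawn from the training data loader, both set up as in `cpu_inference_acceleration.py`:

```python
from bigdl.chronos.pytorch import TSTrainer as Trainer

# Export the trained model to onnxruntime for CPU inference.
# `lit_model` and `input_sample` are assumed to be prepared as in
# cpu_inference_acceleration.py in this directory.
speed_model = Trainer.trace(lit_model,
                            accelerator="onnxruntime",
                            input_sample=input_sample)

# The accelerated model is then called exactly like the original one.
prediction = speed_model(input_sample)
```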

## How to run this example
```bash
python cpu_inference_acceleration.py
```

## Sample output
```bash
Epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 288/288
original pytorch latency (ms): {'p50': 1.236, 'p90': 1.472, 'p95': 1.612, 'p99': 32.989}
onnxruntime latency (ms): {'p50': 0.124, 'p90': 0.129, 'p95': 0.148, 'p99': 0.363}
```