diff --git a/docs/source/export.md b/docs/source/export.md
index 88d5d03f138..cfc85f145a8 100644
--- a/docs/source/export.md
+++ b/docs/source/export.md
@@ -67,6 +67,9 @@ int8_onnx_config = Torch2ONNXConfig(
 )
 q_model.export('int8-model.onnx', int8_onnx_config)
 ```
+> **Note**: Two export examples, covering a computer vision task and a natural language processing task, are provided in examples. Users can leverage them to verify the accuracy and performance of the exported ONNX models.
+ - [Image recognition](/examples/pytorch/image_recognition/torchvision_models/export/fx/)
+ - [Text classification](/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/)
 
 # Appendix
diff --git a/examples/.config/model_params_pt2onnx.json b/examples/.config/model_params_pt2onnx.json
new file mode 100644
index 00000000000..c383a53e2a3
--- /dev/null
+++ b/examples/.config/model_params_pt2onnx.json
@@ -0,0 +1,36 @@
+{
+  "pt2onnx": {
+    "resnet18": {
+      "model_src_dir": "image_recognition/torchvision_models/export/fx",
+      "source_model_dataset": "/tf_dataset/pytorch/ImageNet/raw",
+      "target_model_dataset": "/tf_dataset2/datasets/imagenet/ImagenetRaw/ImagenetRaw_small_5000",
+      "input_model": "resnet18",
+      "main_script": "main.py",
+      "batch_size": 100
+    },
+    "resnet50": {
+      "model_src_dir": "image_recognition/torchvision_models/export/fx",
+      "source_model_dataset": "/tf_dataset/pytorch/ImageNet/raw",
+      "target_model_dataset": "/tf_dataset2/datasets/imagenet/ImagenetRaw/ImagenetRaw_small_5000",
+      "input_model": "resnet50",
+      "main_script": "main.py",
+      "batch_size": 100
+    },
+    "bert_base_MRPC": {
+      "model_src_dir": "nlp/huggingface_models/text-classification/export/fx",
+      "source_model_dataset": "mrpc",
+      "target_model_dataset": "mrpc",
+      "input_model": "/tf_dataset/pytorch/glue_data/base_weights/bert_MRPC_output",
+      "main_script": "run_glue.py",
+      "batch_size": 64
+    },
+    "bert_large_MRPC": {
+      "model_src_dir": "nlp/huggingface_models/text-classification/export/fx",
+      "source_model_dataset": "mrpc",
+      "target_model_dataset": "mrpc",
+      "input_model": "/tf_dataset/pytorch/glue_data/weights/bert_MRPC_output",
+      "main_script": "run_glue.py",
+      "batch_size": 64
+    }
+  }
+}
\ No newline at end of file
diff --git a/examples/pytorch/image_recognition/torchvision_models/export/fx/README.md b/examples/pytorch/image_recognition/torchvision_models/export/fx/README.md
new file mode 100644
index 00000000000..e2e9b22b658
--- /dev/null
+++ b/examples/pytorch/image_recognition/torchvision_models/export/fx/README.md
@@ -0,0 +1,49 @@
+Step-by-Step
+============
+
+This document describes the step-by-step instructions for reproducing the PyTorch tuning results with Intel® Neural Compressor.
+
+# Prerequisite
+
+## 1. Environment
+
+PyTorch 1.8 or a higher version with the pytorch_fx backend is required.
+
+```shell
+cd examples/pytorch/image_recognition/torchvision_models/export/fx
+pip install -r requirements.txt
+```
+> Note: Validated PyTorch [Version](/docs/source/installation_guide.md#validated-software-environment).
+
+## 2. Prepare Dataset
+
+Download the raw [ImageNet](http://www.image-net.org/) images to a local directory. The PyTorch and ONNX evaluation directories should contain the following entries:
+
+```bash
+ls /path/to/pytorch-imagenet
+train val
+ls /path/to/onnx-imagenet-validation
+ILSVRC2012_img_val val.txt
+```
+
+# Run
+### 1. Export the model
+
+Run `run_export.sh` to export an ONNX model from the PyTorch model.
+
+```bash
+# export fp32 model
+bash run_export.sh --input_model=resnet50 --dtype=fp32 --dataset_location=/path/to/pytorch-imagenet --output_model=resnet50-fp32.onnx
+# export int8 model
+bash run_export.sh --input_model=resnet50 --dtype=int8 --quant_format=[QDQ|QLinear] --dataset_location=/path/to/pytorch-imagenet --output_model=resnet50-int8.onnx
+```
+
+### 2. Benchmark the exported and tuned models (accuracy and performance)
+Run `run_benchmark.sh` to benchmark the accuracy and performance of the ONNX models and the PyTorch model.
+```bash
+# benchmark ONNX model
+bash run_benchmark.sh --input_model=[resnet50-fp32.onnx|resnet50-int8.onnx] --dataset_location=/path/to/onnx-imagenet-validation --mode=[accuracy|performance] --batch_size=[16]
+# benchmark PyTorch model
+bash run_benchmark.sh --input_model=[resnet50|/path/to/saved_results] --dataset_location=/path/to/pytorch-imagenet --mode=[accuracy|performance] --int8=[true|false] --batch_size=[16]
+```
+
+> Note: Any model name from `torchvision.models` can be passed as `--input_model`; `resnet50` above is only an example.
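+
+For reference, the export step that `run_export.sh` drives is implemented in `main.py` with Neural Compressor's export API. A condensed sketch of the fp32 path (the int8 path first quantizes the model with `quantization.fit` and then exports it the same way):
+
+```python
+import torch
+import torchvision.models as models
+from neural_compressor.config import Torch2ONNXConfig
+from neural_compressor.model import Model
+
+# wrap the pretrained torchvision model so it gains an export() method
+inc_model = Model(models.resnet50(pretrained=True))
+fp32_onnx_config = Torch2ONNXConfig(
+    dtype="fp32",
+    example_inputs=torch.randn(1, 3, 224, 224),
+    input_names=['input'],
+    output_names=['output'],
+    dynamic_axes={"input": {0: "batch_size"},
+                  "output": {0: "batch_size"}},
+)
+inc_model.export('resnet50-fp32.onnx', fp32_onnx_config)
+```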
diff --git a/examples/pytorch/image_recognition/torchvision_models/export/fx/main.py b/examples/pytorch/image_recognition/torchvision_models/export/fx/main.py
new file mode 100644
index 00000000000..5c2e29d8eaa
--- /dev/null
+++ b/examples/pytorch/image_recognition/torchvision_models/export/fx/main.py
@@ -0,0 +1,406 @@
+import argparse
+import os
+import random
+import shutil
+import time
+import warnings
+import sys
+
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.distributed as dist
+import torch.optim
+import torch.multiprocessing as mp
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+import torchvision.models as models
+
+model_names = sorted(name for name in models.__dict__
+                     if name.islower() and not name.startswith("__")
+                     and callable(models.__dict__[name]))
+
+parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
+parser.add_argument('data', metavar='DIR',
+                    help='path to dataset')
+parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
+                    choices=model_names,
+                    help='model architecture: ' +
+                         ' | '.join(model_names) +
+                         ' (default: resnet18)')
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                    help='number of data loading workers (default: 4)')
+parser.add_argument('--epochs', default=90, type=int, metavar='N',
+                    help='number of total epochs to run')
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+                    help='manual epoch number (useful on restarts)')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+                    metavar='N',
+                    help='mini-batch size (default: 256), this is the total '
+                         'batch size of all GPUs on the current node when '
+                         'using Data Parallel or Distributed Data Parallel')
+parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                    metavar='LR', help='initial learning rate', dest='lr')
+parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                    help='momentum')
+parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+                    metavar='W', help='weight decay (default: 1e-4)',
+                    dest='weight_decay')
+parser.add_argument('-p', '--print-freq', default=10, type=int,
+                    metavar='N', help='print frequency (default: 10)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+                    help='evaluate model on validation set')
+parser.add_argument('-t', '--tune', dest='tune', action='store_true',
+                    help='tune best int8 model on calibration dataset')
+parser.add_argument('--pretrained', dest='pretrained', action='store_true',
+                    help='use pre-trained model')
+parser.add_argument('--world-size', default=-1, type=int,
+                    help='number of nodes for distributed training')
+parser.add_argument('--rank', default=-1, type=int,
+                    help='node rank for distributed training')
+parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
+                    help='url used to set up distributed training')
+parser.add_argument('--dist-backend', default='nccl', type=str,
+                    help='distributed backend')
+parser.add_argument('--seed', default=None, type=int,
+                    help='seed for initializing training.')
+parser.add_argument('--gpu', default=None, type=int,
+                    help='GPU id to use.')
+parser.add_argument('--ppn', default=1, type=int,
+                    help='number of processes on each node of distributed training')
+parser.add_argument('--multiprocessing-distributed', action='store_true',
+                    help='Use multi-processing distributed training to launch '
+                         'N processes per node, which has N GPUs. This is the '
+                         'fastest way to use PyTorch for either single node or '
+                         'multi node data parallel training')
+parser.add_argument('-i', "--iter", default=0, type=int,
+                    help='For accuracy measurement only.')
+parser.add_argument('-w', "--warmup_iter", default=5, type=int,
+                    help='For benchmark measurement only.')
+parser.add_argument('--performance', dest='performance', action='store_true',
+                    help='run benchmark')
+parser.add_argument('-r', "--accuracy", dest='accuracy', action='store_true',
+                    help='For accuracy measurement only.')
+parser.add_argument("--tuned_checkpoint", default='./saved_results', type=str, metavar='PATH',
+                    help='path to checkpoint tuned by Neural Compressor (default: ./)')
+parser.add_argument("--output_model", default='./model.onnx', type=str, metavar='PATH',
+                    help='path to onnx model exported by Neural Compressor (default: ./)')
+parser.add_argument('--int8', dest='int8', action='store_true', help='run benchmark')
+parser.add_argument('--export', dest='export', action='store_true', help='run export')
+parser.add_argument('--export_dtype', default='fp32', choices=['fp32', 'int8'],
+                    help='choose the data type [fp32/int8] of the PyTorch model to be exported.')
+parser.add_argument('--quant_format', default='QDQ', choices=['QDQ', 'QLinear'],
+                    help='choose the format [QDQ/QLinear] of the exported int8 ONNX model.')
+
+best_acc1 = 0
+
+
+def main():
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        random.seed(args.seed)
+        torch.manual_seed(args.seed)
+
+    if args.pretrained:
+        print("=> using pre-trained model '{}'".format(args.arch))
+        model = models.__dict__[args.arch](pretrained=True)
+    else:
+        print("=> creating model '{}'".format(args.arch))
+        model = models.__dict__[args.arch]()
+
+    # define loss function (criterion) and optimizer
+    criterion = nn.CrossEntropyLoss()
+
+    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                momentum=args.momentum,
+                                weight_decay=args.weight_decay)
+
+    # optionally resume from a checkpoint
+    if args.resume:
+        if os.path.isfile(args.resume):
+            print("=> loading checkpoint '{}'".format(args.resume))
+            checkpoint = torch.load(args.resume)
+            args.start_epoch = checkpoint['epoch']
+            best_acc1 = checkpoint['best_acc1']
+            if args.gpu is not None:
+                # best_acc1 may be from a checkpoint from a different GPU
+                best_acc1 = best_acc1.to(args.gpu)
+            model.load_state_dict(checkpoint['state_dict'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            print("=> loaded checkpoint '{}' (epoch {})"
+                  .format(args.resume, checkpoint['epoch']))
+        else:
+            print("=> no checkpoint found at '{}'".format(args.resume))
+
+    # Data loading code
+    traindir = os.path.join(args.data, 'train')
+    valdir = os.path.join(args.data, 'val')
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+
+    train_dataset = datasets.ImageFolder(
+        traindir,
+        transforms.Compose([
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            normalize,
+        ]))
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.batch_size, shuffle=True,
+        num_workers=args.workers, pin_memory=True, sampler=None)
+
+    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        normalize,
+    ]))
+
+    val_loader = torch.utils.data.DataLoader(
+        val_dataset,
+        batch_size=args.batch_size, shuffle=False,
+        num_workers=args.workers, pin_memory=True)
+
+    if args.evaluate:
+        validate(val_loader, model, criterion, args)
+        return
+
+    def eval_func(model):
+        accu = validate(val_loader, model, criterion, args)
+        return float(accu)
+
+    from neural_compressor.config import Torch2ONNXConfig
+    if args.export and args.export_dtype == 'fp32':
+        from neural_compressor.model import Model
+        inc_model = Model(model)
+        fp32_onnx_config = Torch2ONNXConfig(
+            dtype="fp32",
+            example_inputs=torch.randn(1, 3, 224, 224),
+            input_names=['input'],
+            output_names=['output'],
+            dynamic_axes={"input": {0: "batch_size"},
+                          "output": {0: "batch_size"}},
+        )
+        inc_model.export(args.output_model, fp32_onnx_config)
+
+    if args.export and args.export_dtype == 'int8':
+        from neural_compressor import PostTrainingQuantConfig
+        from neural_compressor import quantization
+        if 'efficient' in args.arch:
+            # To reduce tuning time and get results faster, the EfficientNet series
+            # models use the MSE_V2 strategy by default.
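+            # MSE_V2 ranks ops by the mean-squared error introduced when they are
+            # quantized, so accuracy-sensitive ops are found in fewer tuning trials.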
+            from neural_compressor.config import TuningCriterion
+            tuning_criterion = TuningCriterion(strategy="mse_v2")
+            conf = PostTrainingQuantConfig(tuning_criterion=tuning_criterion)
+        else:
+            conf = PostTrainingQuantConfig()
+        q_model = quantization.fit(model,
+                                   conf,
+                                   calib_dataloader=val_loader,
+                                   eval_func=eval_func)
+        q_model.save(args.tuned_checkpoint)
+        int8_onnx_config = Torch2ONNXConfig(
+            dtype="int8",
+            opset_version=14,
+            quant_format="QDQ",
+            example_inputs=torch.randn(1, 3, 224, 224),
+            input_names=['input'],
+            output_names=['output'],
+            dynamic_axes={"input": {0: "batch_size"},
+                          "output": {0: "batch_size"}},
+        )
+        q_model.export(args.output_model, int8_onnx_config)
+        return
+
+    if args.performance or args.accuracy:
+        model.eval()
+        if args.int8:
+            from neural_compressor.utils.pytorch import load
+            new_model = load(os.path.abspath(os.path.expanduser(args.tuned_checkpoint)),
+                             model,
+                             dataloader=val_loader)
+        else:
+            new_model = model
+        if args.performance:
+            from neural_compressor.config import BenchmarkConfig
+            from neural_compressor import benchmark
+            b_conf = BenchmarkConfig(warmup=5,
+                                     iteration=args.iter,
+                                     cores_per_instance=4,
+                                     num_of_instance=1)
+            benchmark.fit(new_model, b_conf, b_dataloader=val_loader)
+        if args.accuracy:
+            validate(val_loader, new_model, criterion, args)
+        return
+
+
+def train(train_loader, model, criterion, optimizer, epoch, args):
+    batch_time = AverageMeter('Time', ':6.3f')
+    data_time = AverageMeter('Data', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+    progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1,
+                             top5, prefix="Epoch: [{}]".format(epoch))
+
+    # switch to train mode
+    model.train()
+
+    end = time.time()
+    for i, (input, target) in enumerate(train_loader):
+        # measure data loading time
+        data_time.update(time.time() - end)
+
+        if args.gpu is not None:
+            input = input.cuda(args.gpu, non_blocking=True)
+            target = target.cuda(args.gpu, non_blocking=True)
+
+        # compute output
+        output = model(input)
+        loss = criterion(output, target)
+
+        # measure accuracy and record loss
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+        losses.update(loss.item(), input.size(0))
+        top1.update(acc1[0], input.size(0))
+        top5.update(acc5[0], input.size(0))
+
+        # compute gradient and do SGD step
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % args.print_freq == 0:
+            progress.print(i)
+
+
+def validate(val_loader, model, criterion, args):
+    batch_time = AverageMeter('Time', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+    progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5,
+                             prefix='Test: ')
+
+    # switch to evaluate mode
+    model.eval()
+
+    with torch.no_grad():
+        for i, (input, target) in enumerate(val_loader):
+            if i >= args.warmup_iter:
+                start = time.time()
+            if args.gpu is not None:
+                input = input.cuda(args.gpu, non_blocking=True)
+                target = target.cuda(args.gpu, non_blocking=True)
+
+            # compute output
+            output = model(input)
+            loss = criterion(output, target)
+
+            # measure accuracy and record loss
+            acc1, acc5 = accuracy(output, target, topk=(1, 5))
+            losses.update(loss.item(), input.size(0))
+            top1.update(acc1[0], input.size(0))
+            top5.update(acc5[0], input.size(0))
+
+            # measure elapsed time
+            if i >= args.warmup_iter:
+                batch_time.update(time.time() - start)
+
+            if i % args.print_freq == 0:
+                progress.print(i)
+
+            if args.iter > 0 and i >= (args.warmup_iter + args.iter - 1):
+                break
+
+        print('Batch size = %d' % args.batch_size)
+        print('Accuracy: {top1:.5f} Accuracy@5 {top5:.5f}'
+              .format(top1=(top1.avg / 100), top5=(top5.avg / 100)))
+
+    return top1.avg
+
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, 'model_best.pth.tar')
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+
+class ProgressMeter(object):
+    def __init__(self, num_batches, *meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def print(self, batch):
+        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+        entries += [str(meter) for meter in self.meters]
+        print('\t'.join(entries))
+
+    def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches // 1))
+        fmt = '{:' + str(num_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+
+def adjust_learning_rate(optimizer, epoch, args):
+    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
+    lr = args.lr * (0.1 ** (epoch // 30))
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/pytorch/image_recognition/torchvision_models/export/fx/onnx_evaluation.py b/examples/pytorch/image_recognition/torchvision_models/export/fx/onnx_evaluation.py
new file mode 100644
index 00000000000..fb8021a4a14
--- /dev/null
+++ b/examples/pytorch/image_recognition/torchvision_models/export/fx/onnx_evaluation.py
@@ -0,0 +1,281 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint:disable=redefined-outer-name,logging-format-interpolation
+
+
+import logging
+import argparse
+import cv2
+import numpy as np
+import onnx
+import re
+import os
+from PIL import Image
+import onnxruntime as ort
+from sklearn.metrics import accuracy_score
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt='%m/%d/%Y %H:%M:%S',
+                    level=logging.WARN)
+
+class Squeeze:
+    def __call__(self, sample):
+        preds, labels = sample
+        return np.squeeze(preds), labels
+
+def _topk_shape_validate(preds, labels):
+    # preds shape can be Nxclass_num or class_num(N=1 by default)
+    # it's more suitable for 'Accuracy' with preds shape Nx1(or 1) output from argmax
+    if isinstance(preds, int):
+        preds = [preds]
+        preds = np.array(preds)
+    elif isinstance(preds, np.ndarray):
+        preds = np.array(preds)
+    elif isinstance(preds, list):
+        preds = np.array(preds)
+        preds = preds.reshape((-1, preds.shape[-1]))
+
+    # consider labels just int value 1x1
+    if isinstance(labels, int):
+        labels = [labels]
+        labels = np.array(labels)
+    elif isinstance(labels, tuple):
+        labels = np.array([labels])
+        labels = labels.reshape((labels.shape[-1], -1))
+    elif isinstance(labels, list):
+        if isinstance(labels[0], int):
+            labels = np.array(labels)
+            labels = labels.reshape((labels.shape[0], 1))
+        elif isinstance(labels[0], tuple):
+            labels = np.array(labels)
+            labels = labels.reshape((labels.shape[-1], -1))
+        else:
+            labels = np.array(labels)
+    # labels must have 2 axes; 2 cases: N (or Nx1 sparse) or Nxclass_num (one-hot)
+    # only 2-dimensional one-hot labels are supported, because
+    # 1-dimensional one-hot labels of length class_num would be confused with N
+
+    if len(preds.shape) == 1:
+        N = 1
+        class_num = preds.shape[0]
+        preds = preds.reshape([-1, class_num])
+    elif len(preds.shape) >= 2:
+        N = preds.shape[0]
+        preds = preds.reshape([N, -1])
+        class_num = preds.shape[1]
+
+    label_N = labels.shape[0]
+    assert label_N == N, 'labels batch size should be the same as preds'
+    labels = labels.reshape([N, -1])
+    # one-hot labels will have a second dimension not equal to 1
+    if labels.shape[1] != 1:
+        labels = labels.argsort()[..., -1:]
+    return preds, labels
+
+class TopK:
+    def __init__(self, k=1):
+        self.k = k
+        self.num_correct = 0
+        self.num_sample = 0
+
+    def update(self, preds, labels, sample_weight=None):
+        preds, labels = _topk_shape_validate(preds, labels)
+        preds = preds.argsort()[..., -self.k:]
+        if self.k == 1:
+            correct = accuracy_score(preds, labels, normalize=False)
+            self.num_correct += correct
+        else:
+            for p, l in zip(preds, labels):
+                # get top-k labels with np.argpartition
+                # p = np.argpartition(p, -self.k)[-self.k:]
+                l = l.astype('int32')
+                if l in p:
+                    self.num_correct += 1
+
+        self.num_sample += len(labels)
+
+    def reset(self):
+        self.num_correct = 0
+        self.num_sample = 0
+
+    def result(self):
+        if self.num_sample == 0:
+            logger.warning("Sample num during evaluation is 0.")
+            return 0
+        elif getattr(self, '_hvd', None) is not None:
+            allgather_num_correct = sum(self._hvd.allgather_object(self.num_correct))
+            allgather_num_sample = sum(self._hvd.allgather_object(self.num_sample))
+            return allgather_num_correct / allgather_num_sample
+        return self.num_correct / self.num_sample
+
+class Dataloader:
+    def __init__(self, dataset_location, image_list, batch_size=1):
+        self.batch_size = batch_size
+        self.image_list = []
+        self.label_list = []
+        self.resize_side = 256
+        self.crop_size = 224
+        self.mean_value = [0.485, 0.456, 0.406]
+        self.std_value = [0.229, 0.224, 0.225]
+        with open(image_list, 'r') as f:
+            for s in f:
+                image_name, label = re.split(r"\s+", s.strip())
+                src = os.path.join(dataset_location, image_name)
+                if not os.path.exists(src):
+                    continue
+
+                self.image_list.append(src)
+                self.label_list.append(int(label))
+
+    def __iter__(self):
+        batched_image = None
+        batched_label = None
+        for index, (src, label) in enumerate(zip(self.image_list, self.label_list)):
+            with Image.open(src) as image:
+                image = np.array(image.convert('RGB')).astype(np.float32)
+                height, width = image.shape[0], image.shape[1]
+                scale = self.resize_side / width if height > width else self.resize_side / height
+                new_height = int(height*scale)
+                new_width = int(width*scale)
+                # cv2.resize expects dsize as (width, height)
+                image = cv2.resize(image, (new_width, new_height))
+                image = image / 255.
+                shape = image.shape
+                y0 = (shape[0] - self.crop_size) // 2
+                x0 = (shape[1] - self.crop_size) // 2
+                if len(image.shape) == 2:
+                    image = np.array([image])
+                    image = np.repeat(image, 3, axis=0)
+                    image = image.transpose(1, 2, 0)
+                image = image[y0:y0+self.crop_size, x0:x0+self.crop_size, :]
+                image = ((image - self.mean_value)/self.std_value).astype(np.float32)
+                image = image.transpose(2, 0, 1)
+            image = np.expand_dims(image, axis=0)
+            label = np.expand_dims(label, axis=0)
+            if batched_label is None:
+                batched_image = image
+                batched_label = label
+            else:
+                batched_image = np.append(batched_image, image, axis=0)
+                batched_label = np.append(batched_label, label, axis=0)
+            if (index + 1) % self.batch_size == 0:
+                yield batched_image, batched_label
+                batched_image = None
+                batched_label = None
+        if (index + 1) % self.batch_size != 0:
+            yield batched_image, batched_label
+
+def eval_func(model, dataloader, metric, postprocess):
+    metric.reset()
+    sess = ort.InferenceSession(model.SerializeToString(), providers=ort.get_available_providers())
+    input_names = [i.name for i in sess.get_inputs()]
+    for input_data, label in dataloader:
+        output = sess.run(None, dict(zip(input_names, [input_data])))
+        output, label = postprocess((output, label))
+        metric.update(output, label)
+    return metric.result()
+
+if __name__ == "__main__":
+    logger.info("Evaluating ONNXRuntime full precision accuracy and performance:")
+    parser = argparse.ArgumentParser(
+        description="Evaluate ONNX models exported from torchvision image classification models.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        '--model_path',
+        type=str,
+        help="path to the pre-trained ONNX model"
+    )
+    parser.add_argument(
+        '--dataset_location',
+        type=str,
+        help="ImageNet data path"
+    )
+    parser.add_argument(
+        '--benchmark',
+        action='store_true',
+        default=False
+    )
+    parser.add_argument(
+        '--tune',
+        action='store_true',
+        default=False,
+        help="whether to quantize the model"
+    )
+    parser.add_argument(
+        '--output_model',
+        type=str,
+        help="output model path"
+    )
+    parser.add_argument(
+        '--batch_size',
+        type=int,
+        help="batch_size of dataloader"
+    )
+    parser.add_argument(
+        '--iters',
+        type=int,
+        help="iterations of dataloader"
+    )
+    parser.add_argument(
+        '--mode',
+        type=str,
+        help="benchmark mode: performance or accuracy"
+    )
+    parser.add_argument(
+        '--quant_format',
+        type=str,
+        default='default',
+        choices=['default', 'QDQ', 'QOperator'],
+        help="quantization format"
+    )
+    args = parser.parse_args()
+
+    model = onnx.load(args.model_path)
+    data_path = os.path.join(args.dataset_location, 'ILSVRC2012_img_val')
+    label_path = os.path.join(args.dataset_location, 'val.txt')
+    dataloader = Dataloader(data_path, label_path, args.batch_size)
+    top1 = TopK()
+    postprocess = Squeeze()
+    def eval(onnx_model):
+        return eval_func(onnx_model, dataloader, top1, postprocess)
+
+    if args.benchmark:
+        if args.mode == 'performance':
+            from neural_compressor.benchmark import fit
+            from neural_compressor.config import BenchmarkConfig
+            conf = BenchmarkConfig(
+                warmup=10,
+                iteration=args.iters,
+                cores_per_instance=4,
+                num_of_instance=1
+            )
+            fit(model, conf, b_dataloader=dataloader)
+        elif args.mode == 'accuracy':
+            acc_result = eval(model)
+            print("Batch size = %d" % dataloader.batch_size)
+            print("Accuracy: %.5f" % acc_result)
+    if args.tune:
+        from neural_compressor import quantization, PostTrainingQuantConfig
+        config = PostTrainingQuantConfig(quant_format=args.quant_format)
+
+        q_model = quantization.fit(model, config, calib_dataloader=dataloader,
+                                   eval_func=eval)
+
+        q_model.save(args.output_model)
diff --git a/examples/pytorch/image_recognition/torchvision_models/export/fx/requirements.txt b/examples/pytorch/image_recognition/torchvision_models/export/fx/requirements.txt
new file mode 100644
index 00000000000..94f1a7356fe
--- /dev/null
+++ b/examples/pytorch/image_recognition/torchvision_models/export/fx/requirements.txt
@@ -0,0 +1,3 @@
+neural-compressor
+torch>=1.9.0
+torchvision>=0.10.0
diff --git a/examples/pytorch/image_recognition/torchvision_models/export/fx/run_benchmark.sh b/examples/pytorch/image_recognition/torchvision_models/export/fx/run_benchmark.sh
new file mode 100644
index 00000000000..fd6b18749e5
--- /dev/null
+++ b/examples/pytorch/image_recognition/torchvision_models/export/fx/run_benchmark.sh
@@ -0,0 +1,82 @@
+#!/bin/bash
+set -x
+
+function main {
+
+    init_params "$@"
+    run_benchmark
+
+}
+
+# init params
+function init_params {
+    iters=100
+    batch_size=32
+    tuned_checkpoint=saved_results
+    for var in "$@"
+    do
+        case $var in
+            --input_model=*)
+                input_model=$(echo $var |cut -f2 -d=)
+            ;;
+            --dataset_location=*)
+                dataset_location=$(echo $var |cut -f2 -d=)
+            ;;
+            --mode=*)
+                mode=$(echo $var |cut -f2 -d=)
+            ;;
+            --batch_size=*)
+                batch_size=$(echo $var |cut -f2 -d=)
+            ;;
+            --iters=*)
+                iters=$(echo ${var} |cut -f2 -d=)
+            ;;
+            --int8=*)
+                int8=$(echo ${var} |cut -f2 -d=)
+            ;;
+            --config=*)
+                tuned_checkpoint=$(echo $var |cut -f2 -d=)
+            ;;
+        esac
+    done
+
+}
+
+
+# run_benchmark
+function run_benchmark {
+    if [[ ${input_model: -5:5} == ".onnx" ]]; then
+        python onnx_evaluation.py \
+            --model_path ${input_model} \
+            --dataset_location ${dataset_location} \
+            --batch_size=${batch_size} \
+            --iters=${iters} \
+            --mode=${mode} \
+            --benchmark
+    else
+        if [[ ${mode} == "accuracy" ]]; then
+            mode_cmd=" --accuracy"
+        elif [[ ${mode} == "performance" ]]; then
+            mode_cmd=" --iter ${iters} --performance "
+        else
+            echo "Error: No such mode: ${mode}"
+            exit 1
+        fi
+
+        extra_cmd=""
+        if [[ ${int8} == "true" ]]; then
+            extra_cmd=$extra_cmd" --int8"
+        fi
+
+        python main.py \
+            --pretrained \
+            --tuned_checkpoint ${tuned_checkpoint} \
+            -b ${batch_size} \
+            -a ${input_model} \
+            ${mode_cmd} \
+            ${extra_cmd} \
+            ${dataset_location}
+    fi
+}
+
+main "$@"
diff --git a/examples/pytorch/image_recognition/torchvision_models/export/fx/run_export.sh b/examples/pytorch/image_recognition/torchvision_models/export/fx/run_export.sh
new file mode 100644
index 00000000000..366db7e850b
--- /dev/null
+++ b/examples/pytorch/image_recognition/torchvision_models/export/fx/run_export.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+set -x
+
+function main {
+
+    init_params "$@"
+    run_tuning
+
+}
+
+# init params
+function init_params {
+    dtype='fp32'
+    quant_format='QDQ' # or QLinear
+    tuned_checkpoint=saved_results
+    for var in "$@"
+    do
+        case $var in
+            --input_model=*)
+                input_model=$(echo $var |cut -f2 -d=)
+            ;;
+            --dataset_location=*)
+                dataset_location=$(echo $var |cut -f2 -d=)
+            ;;
+            --output_model=*)
+                output_model=$(echo $var |cut -f2 -d=)
+            ;;
+            --dtype=*)
+                dtype=$(echo $var |cut -f2 -d=)
+            ;;
+            --quant_format=*)
+                quant_format=$(echo $var |cut -f2 -d=)
+            ;;
+        esac
+    done
+
+}
+
+# run_tuning
+function run_tuning {
+    python main.py \
+        --pretrained \
+        -t \
+        -a ${input_model} \
+        -b 30 \
+        --tuned_checkpoint ${tuned_checkpoint} \
+        --output_model ${output_model} \
+        --export \
+        --export_dtype ${dtype} \
+        --quant_format ${quant_format} \
+        ${dataset_location}
+
+}
+
+main "$@"
diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/README.md b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/README.md
new file mode 100644
index 00000000000..432eed9a1d2
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/README.md
@@ -0,0 +1,62 @@
+Step-by-Step
+============
+
+This document lists the steps to reproduce the PyTorch BERT tuning results.
+For the original BERT documentation, please refer to [BERT README](../../../../common/README.md) and [README](../../../../common/examples/text-classification/README.md).
+
+> **Note**
+>
+> Dynamic quantization is the recommended method for Hugging Face models.
+
+# Prerequisite
+
+## 1. Installation
+
+### Python Version
+
+Python 3.6 or a higher version is recommended.
+
+#### Install BERT model
+
+```bash
+pip install transformers
+```
+
+#### Install dependency
+
+```shell
+pip install -r requirements.txt
+```
+
+#### Install PyTorch
+```shell
+pip install torch
+```
+
+## 2. Prepare pretrained model
+
+Before using Intel® Neural Compressor, you should fine-tune the model to get a pretrained model, or reuse a fine-tuned model from the [model hub](https://huggingface.co/models). You should also install the additional packages required by the example.
+
+## 3. Prepare dataset
+
+Pass in the name of the dataset; supported datasets are 'mrpc', 'qqp', 'qnli', 'rte', 'sts-b', 'cola', 'mnli', 'wnli', and 'sst2'.
+
+
+# Run
+
+### 1. Export the model
+
+```bash
+# export fp32 model
+bash run_export.sh --input_model=[model_name_or_path] --dataset_location=[dataset_name] --dtype=fp32 --output_model=bert-fp32.onnx
+# export int8 model
+bash run_export.sh --input_model=[model_name_or_path] --dataset_location=[dataset_name] --dtype=int8 --quant_format=[QDQ|QLinear] --output_model=bert-int8.onnx
+```
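+
+For reference, the export configuration built inside `run_glue.py` follows the same `Torch2ONNXConfig` pattern documented in [export.md](/docs/source/export.md). A rough sketch of the int8 case (the input names and shapes below are illustrative for a BERT-style model, not copied from the script):
+
+```python
+import torch
+from neural_compressor.config import Torch2ONNXConfig
+
+example_inputs = (
+    torch.zeros(1, 128, dtype=torch.long),  # input_ids (illustrative shape)
+    torch.zeros(1, 128, dtype=torch.long),  # attention_mask
+    torch.zeros(1, 128, dtype=torch.long),  # token_type_ids
+)
+int8_onnx_config = Torch2ONNXConfig(
+    dtype="int8",
+    opset_version=14,
+    quant_format="QDQ",  # or "QLinear"
+    example_inputs=example_inputs,
+    input_names=['input_ids', 'attention_mask', 'token_type_ids'],
+    output_names=['logits'],
+)
+# q_model is the tuned model returned by neural_compressor.quantization.fit
+q_model.export('bert-int8.onnx', int8_onnx_config)
+```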
+
+### 2. Benchmark the exported and tuned models (accuracy and performance)
+```bash
+# benchmark ONNX model
+bash run_benchmark.sh --input_model=[bert-fp32.onnx|bert-int8.onnx] --dataset_location=[dataset_name] --tokenizer=[model_name_or_path] --mode=[accuracy|performance] --batch_size=[16]
+# benchmark PyTorch model
+bash run_benchmark.sh --input_model=[model_name_or_path|/path/to/saved_results] --dataset_location=[dataset_name] --mode=[accuracy|performance] --int8=[true|false] --batch_size=[16]
+```
diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/onnx_evaluation.py b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/onnx_evaluation.py
new file mode 100644
index 00000000000..0716edc6f2b
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/onnx_evaluation.py
@@ -0,0 +1,255 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint:disable=redefined-outer-name,logging-format-interpolation
+
+import logging
+import argparse
+import onnx
+import onnxruntime as ort
+import transformers
+import os
+import torch
+import numpy as np
+from datasets import load_dataset
+from transformers import AutoTokenizer
+from neural_compressor.data.dataloaders.onnxrt_dataloader import DefaultDataLoader
+from neural_compressor.data.datasets.dummy_dataset import DummyDataset
+
+
+task_to_keys = {
+    "cola": ("sentence", None),
+    "mnli": ("premise", "hypothesis"),
+    "mrpc": ("sentence1", "sentence2"),
+    "qnli": ("question", "sentence"),
+    "qqp": ("question1", "question2"),
+    "rte": ("sentence1", "sentence2"),
+    "sst2": ("sentence", None),
+    "stsb": ("sentence1", "sentence2"),
+    "wnli": ("sentence1", "sentence2"),
+}
+
+class ONNXRTBertDataset:
+    def __init__(self, task, model_name_or_path, max_seq_length=128, data_dir=None):
+        raw_dataset = load_dataset('glue', task, cache_dir=data_dir, split='validation')
+        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        sentence1_key, sentence2_key = task_to_keys[task]
+        origin_keys = raw_dataset[0].keys()
+
+        def preprocess_function(examples):
+            # Tokenize the texts
+            args = (
+                (examples[sentence1_key],) if sentence2_key is None
+                else (examples[sentence1_key], examples[sentence2_key])
+            )
+            result = tokenizer(*args, padding="max_length", max_length=max_seq_length, truncation=True)
+            if "label" in examples:
+                result["label"] = examples["label"]
+            return result
+
+        self.dataset = raw_dataset.map(
+            preprocess_function, batched=True, load_from_cache_file=True, remove_columns=origin_keys
+        )
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, index):
+        batch = {k: np.asarray(v) for k, v in self.dataset[index].items()}
+        label = batch.pop('label')
+        return batch, label
+
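+# INCDataloader groups the per-sample feature dicts from ONNXRTBertDataset into
+# batched numpy arrays keyed by input name, matching the ONNX model's inputs.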
+class INCDataloader:
+    def __init__(self, dataset, batch_size=1):
+        import math
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.length = math.ceil(len(self.dataset) / self.batch_size)
+        self.example_input = self.dataset[0][0]
+
+    def __iter__(self):
+        batched_input = {k: None for k in self.example_input}
+        batched_label = None
+        for idx, (input, label) in enumerate(self.dataset):
+            label = np.expand_dims(label, axis=0)
+            for k, v in input.items():
+                v = np.expand_dims(v, axis=0)
+                if batched_input[k] is None:
+                    batched_input[k] = v
+                else:
+                    batched_input[k] = np.append(batched_input[k], v, axis=0)
+            if batched_label is None:
+                batched_label = label
+            else:
+                batched_label = np.append(batched_label, label, axis=0)
+            if (idx+1) % self.batch_size == 0:
+                yield batched_input, batched_label
+                batched_input = {k: None for k in self.example_input}
+                batched_label = None
+        if (idx+1) % self.batch_size != 0:
+            yield batched_input, batched_label
+
+    def __len__(self):
+        return self.length
+
+class ONNXRTGLUE:
+    """Computes GLUE score.
+
+    Args:
+        task (str, default=mrpc): The name of the task.
+                                  Choices include mrpc, qqp, qnli, rte,
+                                  sts-b, cola, mnli, wnli.
+
+    """
+    def __init__(self, task='mrpc'):
+        assert task in ['mrpc', 'qqp', 'qnli', 'rte', 'sts-b', 'cola',
+                        'mnli', 'wnli', 'sst2'], 'Unsupported task type'
+        self.pred_list = None
+        self.label_list = None
+        self.task = task
+        self.return_key = {
+            "cola": "mcc",
+            "mrpc": "f1",
+            "sts-b": "corr",
+            "qqp": "acc",
+            "mnli": "mnli/acc",
+            "qnli": "acc",
+            "rte": "acc",
+            "wnli": "acc",
+            "sst2": "acc"
+        }
+
+    def update(self, preds, labels):
+        if self.pred_list is None:
+            self.pred_list = preds
+            self.label_list = labels
+        else:
+            self.pred_list = np.append(self.pred_list, preds, axis=0)
+            self.label_list = np.append(self.label_list, labels, axis=0)
+
+    def reset(self):
+        """Clear the stored predictions and labels."""
+        self.pred_list = None
+        self.label_list = None
+
+    def result(self):
+        """Calculate the metric."""
+        output_mode = transformers.glue_output_modes[self.task]
+
+        if output_mode == "classification":
+            processed_preds = np.argmax(self.pred_list, axis=1)
+        elif output_mode == "regression":
+            processed_preds = np.squeeze(self.pred_list)
+        result = transformers.glue_compute_metrics(
+            self.task, processed_preds, self.label_list)
+        return result[self.return_key[self.task]]
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt='%m/%d/%Y %H:%M:%S',
+                    level=logging.WARN)
+
+if __name__ == "__main__":
+    logger.info('Evaluating ONNXRuntime full precision accuracy and performance:')
+    parser = argparse.ArgumentParser(
+        description='BERT fine-tune examples for classification/regression tasks.',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument(
+        '--model_path',
+        type=str,
+        help="path to the exported BERT model in ONNX format"
+    )
+    parser.add_argument(
+        '--benchmark',
+        action='store_true',
+        default=False,
+        help="benchmark mode of performance"
+    )
+    parser.add_argument(
+        '--config',
+        type=str,
+        help="config yaml path"
+    )
+    parser.add_argument(
+        '--output_model',
+        type=str,
+        default=None,
+        help="output model path"
+    )
+    parser.add_argument(
+        '--accuracy',
+        action='store_true',
+        default=False,
+        help="benchmark mode of accuracy"
+    )
+    parser.add_argument(
+        '--data_path',
+        type=str,
+        help="input data path"
+    )
+    parser.add_argument(
+        '--batch_size',
+        default=8,
+        type=int,
+    )
+    parser.add_argument(
+        '--model_name_or_path',
+        type=str,
+        help="pretrained model name or path"
+    )
+    parser.add_argument(
+        '--task',
+        type=str,
+        choices=['mrpc', 'qqp', 'qnli', 'rte', 'sts-b', 'cola',
+                 'mnli', 'wnli', 'sst2'],
+        help="GLUE task name"
+    )
+    parser.add_argument(
+        '--max_seq_length',
+        default=128,
+        type=int,
+    )
+
+    args = parser.parse_args()
+
+    dataset = ONNXRTBertDataset(task=args.task,
+                                model_name_or_path=args.model_name_or_path,
+                                max_seq_length=args.max_seq_length)
+    dataloader = INCDataloader(dataset, args.batch_size)
+    metric = ONNXRTGLUE(args.task)
+
+    def eval_func(model):
+        metric.reset()
+        from tqdm import tqdm
+        session = ort.InferenceSession(model.SerializeToString(), None)
+        for inputs, labels in tqdm(dataloader):
+            predictions = session.run(None, inputs)
+            metric.update(predictions[0], labels)
+        return metric.result()
+
+    model = onnx.load(args.model_path)
+    if args.benchmark:
+        from neural_compressor.benchmark import fit
+        from neural_compressor.config import BenchmarkConfig
+        conf = BenchmarkConfig(iteration=100,
+                               cores_per_instance=4,
+                               num_of_instance=1)
+        fit(model, conf, b_dataloader=dataloader)
+    elif args.accuracy:
+        acc_result = eval_func(model)
+        print("Batch size = %d" % args.batch_size)
+        print("Accuracy: %.5f" % acc_result)
diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/requirements.txt b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/requirements.txt
new file mode 100644
index 00000000000..3ae7360761d
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/requirements.txt
@@ -0,0 +1,10 @@
+datasets>=1.1.3
+sentencepiece!=0.1.92
+protobuf
+scipy
+scikit-learn
+Keras-Preprocessing
+onnx
+onnxruntime
+transformers>=4.16.0
+torch>=1.9.0
diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_benchmark.sh b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_benchmark.sh
new file mode 100644
index 00000000000..40ca14c55bd
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_benchmark.sh
@@ -0,0 +1,90 @@
+#!/bin/bash
+set -x
+
+function main {
+
+    init_params "$@"
+    run_benchmark
+
+}
+
+# init params
+function init_params {
+    iters=100
+    batch_size=16
+    tuned_checkpoint=saved_results
+    for var in "$@"
+    do
+        case $var in
+            --input_model=*)
+                input_model=$(echo $var |cut -f2 -d=)
+            ;;
+            --dataset_location=*)
+                dataset_location=$(echo $var |cut -f2 -d=)
+            ;;
+            --mode=*)
+                mode=$(echo $var |cut -f2 -d=)
+            ;;
+            --batch_size=*)
+                batch_size=$(echo $var |cut -f2 -d=)
+            ;;
+            --iters=*)
+                iters=$(echo ${var} |cut -f2 -d=)
+            ;;
+            --int8=*)
+                int8=$(echo ${var} |cut -f2 -d=)
+            ;;
+            --config=*)
+                tuned_checkpoint=$(echo $var |cut -f2 -d=)
+            ;;
+            --tokenizer=*)
+                tokenizer=$(echo $var |cut -f2 -d=)
+            ;;
+        esac
+    done
+
+}
+
+
+# run_benchmark
+function run_benchmark {
+    extra_cmd=''
+    MAX_SEQ_LENGTH=128
+    if [[ ${mode} == "accuracy" ]]; then
+        mode_cmd="--accuracy"
+    elif [[ ${mode} == "performance" ]]; then
+        mode_cmd=" --benchmark"
+    else
+        echo "Error: No such mode: ${mode}"
+        exit 1
+    fi
+
+    if [[ ${input_model: -5:5} == ".onnx" ]]; then
+        # fetch the tokenizer configuration for the dataset; ${tokenizer} should be
+        # the original model_name_or_path so tokenization matches the exported model
+        python onnx_evaluation.py \
+            --model_name_or_path ${tokenizer} \
+            --model_path ${input_model} \
+            --task ${dataset_location} \
+            --batch_size=${batch_size} \
+            ${mode_cmd}
+    else
+        extra_cmd='--model_name_or_path '${input_model}
+        if [[ ${int8} == "true" ]]; then
+            extra_cmd='--model_name_or_path '${tuned_checkpoint}
+            extra_cmd=$extra_cmd" --int8"
+        fi
+        echo $extra_cmd
+        python -u run_glue.py \
+            --task_name ${dataset_location} \
+            --do_eval \
+            --max_seq_length ${MAX_SEQ_LENGTH} \
+            --per_device_eval_batch_size ${batch_size} \
+            --no_cuda \
+            --output_dir ./output_log \
+            --overwrite_output_dir \
+            ${mode_cmd} \
+            ${extra_cmd}
+    fi
+}
+
+main "$@"
diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_export.sh b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_export.sh
new file mode 100644
index 00000000000..02226b0aa83
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_export.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+set -x
+
+function main {
+
+    init_params "$@"
+    run_tuning
+
+}
+
+# init params
+function init_params {
+    dtype='fp32'
+    quant_format='QDQ' # or QLinear
+    for var in "$@"
+    do
+        case $var in
+            --input_model=*)
+                input_model=$(echo $var |cut -f2 -d=)
+            ;;
+            --output_model=*)
+                output_model=$(echo $var |cut -f2 -d=)
+            ;;
+            --dataset_location=*)
+                dataset_location=$(echo $var |cut -f2 -d=)
+            ;;
+            --dtype=*)
+                dtype=$(echo $var |cut -f2 -d=)
+            ;;
+            --quant_format=*)
+                quant_format=$(echo $var |cut -f2 -d=)
+            ;;
+        esac
+    done
+
+}
+
+# run_tuning
+function run_tuning {
+    # tuned_checkpoint is used to save the torch int8 model.
+    tuned_checkpoint=saved_results
+    extra_cmd=''
+    batch_size=16
+    MAX_SEQ_LENGTH=128
+    model_name_or_path=${input_model}
+    TASK_NAME=${dataset_location}
+
+    python -u ./run_glue.py \
+        --model_name_or_path ${model_name_or_path} \
+        --task_name ${TASK_NAME} \
+        --do_eval \
+        --do_train \
+        --max_seq_length ${MAX_SEQ_LENGTH} \
+        --per_device_eval_batch_size ${batch_size} \
+        --no_cuda \
+        --output_dir ${tuned_checkpoint} \
+        --output_model ${output_model} \
+        --export \
+        --export_dtype ${dtype} \
+        --quant_format ${quant_format} \
+        --overwrite_output_dir \
+        ${extra_cmd}
+}
+
+main "$@"
diff --git a/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_glue.py b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_glue.py
new file mode 100644
index 00000000000..9658b694c51
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/text-classification/export/fx/run_glue.py
@@ -0,0 +1,593 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Finetuning the library models for sequence classification on GLUE."""
+# You can also adapt this script on your own text classification task. Pointers for this are left as comments.
+
+import logging
+import os
+import random
+import sys
+from dataclasses import dataclass, field
+from typing import Optional
+
+import datasets
+import numpy as np
+import transformers
+from datasets import load_dataset, load_metric
+from transformers import (
+    AutoConfig,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    DataCollatorWithPadding,
+    EvalPrediction,
+    HfArgumentParser,
+    PretrainedConfig,
+    Trainer,
+    TrainingArguments,
+    default_data_collator,
+    set_seed,
+)
+from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils import check_min_version
+from transformers.utils.versions import require_version
+
+
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.10.0")
+
+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
+
+task_to_keys = {
+    "cola": ("sentence", None),
+    "mnli": ("premise", "hypothesis"),
+    "mrpc": ("sentence1", "sentence2"),
+    "qnli": ("question", "sentence"),
+    "qqp": ("question1", "question2"),
+    "rte": ("sentence1", "sentence2"),
+    "sst2": ("sentence", None),
+    "stsb": ("sentence1", "sentence2"),
+    "wnli": ("sentence1", "sentence2"),
+}
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and eval.
+
+    Using `HfArgumentParser` we can turn this class
+    into argparse arguments to be able to specify them on
+    the command line.
+    """
+
+    task_name: Optional[str] = field(
+        default=None,
+        metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
+    )
+    dataset_name: Optional[str] = field(
+        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+    )
+    dataset_config_name: Optional[str] = field(
+        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+    )
+    max_seq_length: int = field(
+        default=128,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+        },
+    )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
+    )
+    pad_to_max_length: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to pad all samples to `max_seq_length`. "
+            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+        },
+    )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        },
+    )
+    max_eval_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+            "value if set."
+        },
+    )
+    max_predict_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+            "value if set."
+        },
+    )
+ }, + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the training data."} + ) + validation_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the validation data."} + ) + + def __post_init__(self): + if self.task_name is not None: + self.task_name = self.task_name.lower() + if self.task_name not in task_to_keys.keys(): + raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) + elif self.dataset_name is not None: + pass + elif self.train_file is None or self.validation_file is None: + raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") + else: + train_extension = self.train_file.split(".")[-1] + assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." + validation_extension = self.validation_file.split(".")[-1] + assert ( + validation_extension == train_extension + ), "`validation_file` should have the same extension (csv or json) as `train_file`." + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + export: bool = field( + default=False, metadata={"help": "export PyTorch model into ONNX model."} + ) + export_dtype: str = field( + default="fp32", metadata={"help": "choose the data type [fp32/int8] of PyTorch model to be exported."} + ) + quant_format: str = field( + default="QDQ", metadata={"help": "choose the format [QDQ/QLinear] of int8 ONNX model exported."} + ) + output_model: str = field( + default="model.onnx", metadata={"help": "the name of exported model."} + ) + int8: bool = field( + default=False, metadata={"help": "use int8 model to get accuracy or benchmark"} + ) + benchmark: bool = field( + default=False, metadata={"help": "get benchmark instead of accuracy"} + ) + accuracy: bool = field( + default=False, metadata={"help": "get accuracy"} + ) + iters: int = field( + default=100, + metadata={ + "help": "The inference iterations to run for benchmark." + }, + ) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. 
+ + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) + # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the + # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named + # label if at least two columns are provided. + # + # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this + # single column. You can easily tweak this behavior (see below) + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.task_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir) + elif data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset( + data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir + ) + else: + # Loading a dataset from your local files. + # CSV/JSON training and evaluation files are needed. 
+        data_files = {"train": data_args.train_file, "validation": data_args.validation_file}
+
+        for key in data_files.keys():
+            logger.info(f"load a local file for {key}: {data_files[key]}")
+
+        if data_args.train_file.endswith(".csv"):
+            # Loading a dataset from local csv files
+            raw_datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir)
+        else:
+            # Loading a dataset from local json files
+            raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir)
+    # See more about loading any type of standard or custom dataset at
+    # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+    # Labels
+    if data_args.task_name is not None:
+        is_regression = data_args.task_name == "stsb"
+        if not is_regression:
+            label_list = raw_datasets["train"].features["label"].names
+            num_labels = len(label_list)
+        else:
+            num_labels = 1
+    else:
+        # Trying to have good defaults here, don't hesitate to tweak to your needs.
+        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
+        if is_regression:
+            num_labels = 1
+        else:
+            # A useful fast method:
+            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
+            label_list = raw_datasets["train"].unique("label")
+            label_list.sort()  # Let's sort it for determinism
+            num_labels = len(label_list)
+
+    # Load pretrained model and tokenizer
+    #
+    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
+    # download model & vocab.
+    config = AutoConfig.from_pretrained(
+        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+        num_labels=num_labels,
+        finetuning_task=data_args.task_name,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        use_fast=model_args.use_fast_tokenizer,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
+    )
+    if model_args.int8:
+        # Load a model previously quantized and saved by Neural Compressor.
+        from neural_compressor.utils.load_huggingface import OptimizedModel
+        model = OptimizedModel.from_pretrained(
+            model_args.model_name_or_path,
+            from_tf=bool(".ckpt" in model_args.model_name_or_path),
+            config=config,
+            cache_dir=model_args.cache_dir,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
+        )
+    else:
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_args.model_name_or_path,
+            from_tf=bool(".ckpt" in model_args.model_name_or_path),
+            config=config,
+            cache_dir=model_args.cache_dir,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
+        )
+
+    # Preprocessing the raw_datasets
+    if data_args.task_name is not None:
+        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
+    else:
+        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
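+        # e.g. (illustrative) columns ["premise", "hypothesis", "label"] would yield
+        # sentence1_key="premise" and sentence2_key="hypothesis" via the logic below.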
+        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
+        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
+            sentence1_key, sentence2_key = "sentence1", "sentence2"
+        else:
+            if len(non_label_column_names) >= 2:
+                sentence1_key, sentence2_key = non_label_column_names[:2]
+            else:
+                sentence1_key, sentence2_key = non_label_column_names[0], None
+
+    # Padding strategy
+    if data_args.pad_to_max_length:
+        padding = "max_length"
+    else:
+        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
+        padding = False
+
+    # Some models have set the order of the labels to use, so let's make sure we do use it.
+    label_to_id = None
+    if (
+        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
+        and data_args.task_name is not None
+        and not is_regression
+    ):
+        # Some have all caps in their config, some don't.
+        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
+        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
+            label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
+        else:
+            logger.warning(
+                f"Your model seems to have been trained with labels, but they don't match the dataset: "
+                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}.\n"
+                f"Ignoring the model labels as a result."
+            )
+    elif data_args.task_name is None and not is_regression:
+        label_to_id = {v: i for i, v in enumerate(label_list)}
+
+    if label_to_id is not None:
+        model.config.label2id = label_to_id
+        model.config.id2label = {label_id: label for label, label_id in config.label2id.items()}
+
+    if data_args.max_seq_length > tokenizer.model_max_length:
+        logger.warning(
+            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
+            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+        )
+    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
+    def preprocess_function(examples):
+        # Tokenize the texts
+        args = (
+            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
+        )
+        result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
+
+        # Map labels to IDs (not necessary for GLUE tasks)
+        if label_to_id is not None and "label" in examples:
+            result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
+        return result
+
+    with training_args.main_process_first(desc="dataset map pre-processing"):
+        raw_datasets = raw_datasets.map(
+            preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache
+        )
+    if training_args.do_train:
+        if "train" not in raw_datasets:
+            raise ValueError("--do_train requires a train dataset")
+        train_dataset = raw_datasets["train"]
+        if data_args.max_train_samples is not None:
+            train_dataset = train_dataset.select(range(data_args.max_train_samples))
+
+    if training_args.do_eval:
+        if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
+            raise ValueError("--do_eval requires a validation dataset")
+        eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+
+    # Log a few random samples from the training set:
+    if training_args.do_train:
+        for index in random.sample(range(len(train_dataset)), 3):
+            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+
+    # Get the metric function
+    if data_args.task_name is not None:
+        metric = load_metric("glue", data_args.task_name)
+    else:
+        metric = load_metric("accuracy")
+
+    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with
+    # `predictions` and `label_ids` fields) and has to return a dictionary mapping strings to floats.
+    def compute_metrics(p: EvalPrediction):
+        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
+        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
+        if data_args.task_name is not None:
+            result = metric.compute(predictions=preds, references=p.label_ids)
+            if len(result) > 1:
+                result["combined_score"] = np.mean(list(result.values())).item()
+            return result
+        elif is_regression:
+            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
+        else:
+            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
+
+    # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
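+    # (With dynamic padding, each batch is padded only to its own longest sequence;
+    # pad_to_multiple_of=8 keeps tensor shapes friendly to fp16 hardware kernels.)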
+    if data_args.pad_to_max_length:
+        data_collator = default_data_collator
+    elif training_args.fp16:
+        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+    else:
+        data_collator = None
+
+    # Initialize our Trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset if training_args.do_train else None,
+        eval_dataset=eval_dataset if training_args.do_eval else None,
+        compute_metrics=compute_metrics,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+    )
+
+    eval_dataloader = trainer.get_eval_dataloader()
+    batch_size = eval_dataloader.batch_size
+
+    def take_eval_steps(model, trainer, save_metrics=False):
+        trainer.model = model
+        metrics = trainer.evaluate()
+        if save_metrics:
+            trainer.save_metrics("eval", metrics)
+        logger.info("metrics keys: {}".format(metrics.keys()))
+        bert_task_acc_keys = ['eval_f1', 'eval_accuracy', 'eval_matthews_correlation',
+                              'eval_pearson', 'eval_mcc', 'eval_spearmanr']
+        for key in bert_task_acc_keys:
+            if key in metrics.keys():
+                throughput = metrics.get("eval_samples_per_second")
+                print('Batch size = %d' % batch_size)
+                print("Final Eval {} Accuracy: {}".format(key, metrics[key]))
+                print("Latency: %.3f ms" % (1000 / throughput))
+                print("Throughput: {} samples/sec".format(throughput))
+                return metrics[key]
+        assert False, "No metric returned, please check the inference metric!"
+
+    def eval_func(model):
+        return take_eval_steps(model, trainer)
+
+    # Build example inputs and dynamic axes for the ONNX export from one evaluation batch.
+    from neural_compressor.config import Torch2ONNXConfig
+    it = iter(eval_dataloader)
+    batch = next(it)
+    batch.pop('labels')
+    symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
+    dynamic_axes = {k: symbolic_names for k in batch.keys()}
+
+    if model_args.export and model_args.export_dtype == 'fp32':
+        from neural_compressor.model import Model
+        inc_model = Model(model)
+        fp32_onnx_config = Torch2ONNXConfig(
+            dtype=model_args.export_dtype,
+            opset_version=14,
+            example_inputs=tuple(batch.values()),
+            input_names=list(batch.keys()),
+            output_names=['labels'],
+            dynamic_axes=dynamic_axes,
+        )
+        inc_model.export(model_args.output_model, fp32_onnx_config)
+
+    # Optimize and quantize with Neural Compressor, then export the int8 ONNX model.
+    if model_args.export_dtype == 'int8':
+        from neural_compressor.quantization import fit
+        from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion
+        tuning_criterion = TuningCriterion(
+            strategy="mse_v2",
+            strategy_kwargs={"confidence_batches": 1},
+            max_trials=600,
+        )
+        conf = PostTrainingQuantConfig(
+            approach="static",
+            tuning_criterion=tuning_criterion,
+            calibration_sampling_size=[300],
+        )
+        q_model = fit(model, conf=conf, calib_dataloader=eval_dataloader, eval_func=eval_func)
+        from neural_compressor.utils.load_huggingface import save_for_huggingface_upstream
+        save_for_huggingface_upstream(q_model, tokenizer, training_args.output_dir)
+
+        int8_onnx_config = Torch2ONNXConfig(
+            dtype=model_args.export_dtype,
+            opset_version=14,
+            quant_format=model_args.quant_format,
+            example_inputs=tuple(batch.values()),
+            input_names=list(batch.keys()),
+            output_names=['labels'],
+            dynamic_axes=dynamic_axes,
+        )
+        q_model.export(model_args.output_model, int8_onnx_config)
+        return
+
+    if model_args.benchmark:
+        from neural_compressor.config import BenchmarkConfig
+        from neural_compressor import benchmark
+        b_conf = BenchmarkConfig(warmup=5,
+                                 iteration=model_args.iters,
+                                 cores_per_instance=4,
+                                 num_of_instance=1)
+        benchmark.fit(model, b_conf, b_dataloader=eval_dataloader)
+    elif model_args.accuracy:
+        eval_func(model)
+
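+    # A minimal sanity check of the exported ONNX model could look like the following
+    # (hypothetical sketch, not executed by this script; assumes onnxruntime is installed):
+    #   import onnxruntime as ort
+    #   session = ort.InferenceSession(model_args.output_model)
+    #   ort_inputs = {k: v.numpy() for k, v in batch.items()}
+    #   logits = session.run(None, ort_inputs)[0]
+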
+    if training_args.push_to_hub:
+        kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
+        if data_args.task_name is not None:
+            kwargs["language"] = "en"
+            kwargs["dataset_tags"] = "glue"
+            kwargs["dataset_args"] = data_args.task_name
+            kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}"
+
+        trainer.push_to_hub(**kwargs)
+
+
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/neural_compressor/conf/config.py b/neural_compressor/conf/config.py
index b9bd6671ce2..857791c9cf7 100644
--- a/neural_compressor/conf/config.py
+++ b/neural_compressor/conf/config.py
@@ -1421,8 +1421,7 @@ def map_pyconfig_to_cfg(self, pythonic_config):
         if pythonic_config.quantization.strategy_kwargs:
             st_kwargs = pythonic_config.quantization.strategy_kwargs
             for st_key in ['sigopt_api_token', 'sigopt_project_id', 'sigopt_experiment_name', \
-                'accuracy_weight', 'latency_weight', 'hawq_v2_loss']:
-
+                'accuracy_weight', 'latency_weight', 'hawq_v2_loss', 'confidence_batches']:
                 if st_key in st_kwargs:
                     st_val = st_kwargs[st_key]
                     mapping.update({'tuning.strategy.' + st_key: st_val})
diff --git a/neural_compressor/experimental/export/torch2onnx.py b/neural_compressor/experimental/export/torch2onnx.py
index 3995ce98707..f8d231d5342 100644
--- a/neural_compressor/experimental/export/torch2onnx.py
+++ b/neural_compressor/experimental/export/torch2onnx.py
@@ -681,19 +681,32 @@ def torch_to_fp32_onnx(
         do_constant_folding (bool, optional): do constant folding or not. Defaults to True.
         verbose (bool, optional): dump verbose or not. Defaults to True.
     """
-    if input_names:
-        example_input_names = input_names
-    else:
-        example_input_names = ['input']
-        if isinstance(example_inputs, dict) or isinstance(example_inputs, UserDict):
-            example_input_names = list(example_inputs.keys())
+    if input_names is None and \
+        (isinstance(example_inputs, dict) or isinstance(example_inputs, UserDict)):
+        input_names = list(example_inputs.keys())
+        example_inputs = list(example_inputs.values())
+    # Match input_names to the inspected order of the forward signature, especially for BERT models from Hugging Face.
+    if input_names and len(input_names) > 1:
+        import inspect
+        input_order = inspect.signature(fp32_model.forward).parameters.keys()
+        flag = [name in input_order for name in input_names]  # True only if every name appears in the signature
+        if all(flag):
+            new_input_names = []
+            new_example_inputs = []
+            for name in input_order:
+                if name in input_names:
+                    new_input_names.append(name)
+                    idx = input_names.index(name)
+                    new_example_inputs.append(example_inputs[idx])
+            input_names = new_input_names
+            example_inputs = new_example_inputs
     torch.onnx.export(
         fp32_model,
         input2tuple(example_inputs),
         save_path,
         opset_version=opset_version,
-        input_names=example_input_names,
+        input_names=input_names,
         output_names=output_names,
         dynamic_axes=dynamic_axes,
         do_constant_folding=do_constant_folding,
@@ -747,6 +760,7 @@ def torch_to_int8_onnx(
     else:
         op_types_to_quantize=['MatMul', 'Gemm', 'Gather', 'Conv']
 
+    quant_format = quant_format.upper()
     if quant_format == 'QDQ' and opset_version < 13:  # pragma: no cover
         opset_version = 13
         logger.warning("QDQ format requires opset_version >= 13, " +