diff --git a/README.md b/README.md index 5fbad1d0434..e7ed35bf5cc 100644 --- a/README.md +++ b/README.md @@ -68,67 +68,52 @@ pip install "neural-compressor>=2.3" "transformers>=4.34.0" torch torchvision ``` After successfully installing these packages, try your first quantization program. -### Weight-Only Quantization (LLMs) -Following example code demonstrates Weight-Only Quantization on LLMs, it supports Intel CPU, Intel Gaudi2 AI Accelerator, Nvidia GPU, best device will be selected automatically. +### [FP8 Quantization](./examples/3.x_api/pytorch/cv/fp8_quant/) +Following example code demonstrates FP8 Quantization, it is supported by Intel Gaudi2 AI Accelerator. To try on Intel Gaudi2, docker image with Gaudi Software Stack is recommended, please refer to following script for environment setup. More details can be found in [Gaudi Guide](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#launch-docker-image-that-was-built). ```bash # Run a container with an interactive shell -docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest - -# Install the optimum-habana -pip install --upgrade-strategy eager optimum[habana] - -# Install INC/auto_round -pip install neural-compressor auto_round +docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest ``` Run the example: ```python -from transformers import AutoModel, AutoTokenizer - -from neural_compressor.config import PostTrainingQuantConfig -from neural_compressor.quantization import fit -from neural_compressor.adaptor.torch_utils.auto_round import get_dataloader - -model_name = "EleutherAI/gpt-neo-125m" -float_model = AutoModel.from_pretrained(model_name) -tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) -dataloader = get_dataloader(tokenizer, seqlen=2048) - -woq_conf = PostTrainingQuantConfig( - approach="weight_only", - op_type_dict={ - ".*": { # match all ops - "weight": { - "dtype": "int", - "bits": 4, - "algorithm": "AUTOROUND", - }, - } - }, +from neural_compressor.torch.quantization import ( + FP8Config, + prepare, + convert, ) -quantized_model = fit(model=float_model, conf=woq_conf, calib_dataloader=dataloader) +import torchvision.models as models + +model = models.resnet18() +qconfig = FP8Config(fp8_config="E4M3") +model = prepare(model, qconfig) +# customer defined calibration +calib_func(model) +model = convert(model) ``` -**Note:** -To try INT4 model inference, please directly use [Intel Extension for Transformers](https://github.com/intel/intel-extension-for-transformers), which leverages Intel Neural Compressor for model quantization. +### [Weight-Only Quantization (LLMs)](./examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/) -### Static Quantization (Non-LLMs) +Following example code demonstrates Weight-Only Quantization on LLMs, it supports Intel CPU, Intel Gaudi2 AI Accelerator, Nvidia GPU, best device will be selected automatically. ```python -from torchvision import models +from neural_compressor.torch.quantization import prepare, convert, AutoRoundConfig -from neural_compressor.config import PostTrainingQuantConfig -from neural_compressor.data import DataLoader, Datasets -from neural_compressor.quantization import fit +model_name = "EleutherAI/gpt-neo-125m" +model = AutoModel.from_pretrained(model_name) -float_model = models.resnet18() -dataset = Datasets("pytorch")["dummy"](shape=(1, 3, 224, 224)) -calib_dataloader = DataLoader(framework="pytorch", dataset=dataset) -static_quant_conf = PostTrainingQuantConfig() -quantized_model = fit(model=float_model, conf=static_quant_conf, calib_dataloader=calib_dataloader) +quant_config = AutoRoundConfig() +model = prepare(model, quant_config) +# customer defined calibration +run_fn(model) # calibration +model = convert(model) ``` +**Note:** + +To try INT4 model inference, please directly use [Intel Extension for Transformers](https://github.com/intel/intel-extension-for-transformers), which leverages Intel Neural Compressor for model quantization. + ## Documentation @@ -154,12 +139,13 @@ quantized_model = fit(model=float_model, conf=static_quant_conf, calib_dataloade - + - + + diff --git a/docs/3x/PT_FP8Quant.md b/docs/3x/PT_FP8Quant.md new file mode 100644 index 00000000000..a0ed3352e8e --- /dev/null +++ b/docs/3x/PT_FP8Quant.md @@ -0,0 +1,113 @@ +FP8 Quantization +======= + +1. [Introduction](#introduction) +2. [Supported Parameters](#supported-parameters) +3. [Get Start with FP8 Quantization](#get-start-with-fp8-quantization) +4. [Examples](#examples) + +## Introduction + +Float point 8 (FP8) is a promising data type for low precision quantization which provides a data distribution that is completely different from INT8 and it's shown as below. + +
+ +
+ +Intel Gaudi2, also known as HPU, provides this data type capability for low precision quantization, which includes `E4M3` and `E5M2`. For more information about these two data type, please refer to [link](https://arxiv.org/abs/2209.05433). + +Intel Neural Compressor provides general quantization APIs to leverage HPU FP8 capability. with simple with lower memory usage and lower compute cost, 8 bit model + +## Supported Parameters + + +
OverviewStatic Quantization Dynamic QuantizationStatic Quantization Smooth Quantization
Weight-Only QuantizationWeight-Only QuantizationFP8 Quantization MX Quantization Mixed Precision
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
AttributeDescriptionValues
fp8_configThe target data type of FP8 quantization.E4M3 (default) - As Fig. 2
E5M2 - As Fig. 1.
hp_dtypeThe high precision data type of non-FP8 operators.bf16 (default) - torch.bfloat16
fp16 - torch.float16.
fp32 - torch.float32.
observerThe observer to measure the statistics.maxabs (default), saves all tensors to files.
allowlistList of nn.Module names or types to quantize. When setting an empty list, all the supported modules will be quantized by default. See Supported Modules. Not setting the list at all is not recommended as it will set the allowlist to these modules only: torch.nn.Linear, torch.nn.Conv2d, and BMM.Default = {'names': [], 'types': FP8_WHITE_LIST}
blocklistList of nn.Module names or types not to quantize. Defaults to empty list, so you may omit it from the config file.Default = {'names': [], 'types': ()}
modeThe mode, measure or quantize, to run HQT with.MEASURE - Measure statistics of all modules and emit the results to dump_stats_path.
QUANTIZE - Quantize and run the model according to the provided measurements.
AUTO (default) - Select from [MEASURE, QUANTIZE] automatically.
dump_stats_pathThe path to save and load the measurements. The path is created up until the level before last "/". The string after the last / will be used as prefix to all the measurement files that will be created.Default = "./hqt_output/measure"
scale_methodThe method for calculating the scale from the measurement.- without_scale - Convert to/from FP8 without scaling.
- unit_scale - Always use scale of 1.
- maxabs_hw (default) - Scale is calculated to stretch/compress the maxabs measurement to the full-scale of FP8 and then aligned to the corresponding HW accelerated scale.
- maxabs_pow2 - Scale is calculated to stretch/compress the maxabs measurement to the full-scale of FP8 and then rounded to the power of 2.
- maxabs_hw_opt_weight - Scale of model params (weights) is chosen as the scale that provides minimal mean-square-error between quantized and non-quantized weights, from all possible HW accelerated scales. Scale of activations is calculated the same as maxabs_hw.
- act_maxabs_pow2_weights_pcs_opt_pow2 - Scale of model params (weights) is calculated per-channel of the params tensor. The scale per-channel is calculated the same as maxabs_hw_opt_weight. Scale of activations is calculated the same as maxabs_pow2.
- act_maxabs_hw_weights_pcs_maxabs_pow2 - Scale of model params (weights) is calculated per-channel of the params tensor. The scale per-channel is calculated the same as maxabs_pow2. Scale of activations is calculated the same as maxabs_hw.
measure_excludeIf this attribute is not defined, the default is OUTPUT. Since most models do not require measuring output tensors, you can exclude it to speed up the measurement process.NONE - All tensors are measured.
OUTPUT (default) - Excludes measurement of output tensors.
+ +## Get Start with FP8 Quantization + +### Demo Usage + +```python +from neural_compressor.torch.quantization import ( + FP8Config, + prepare, + convert, +) +import torchvision.models as models + +model = models.resnet18() +qconfig = FP8Config(fp8_config="E4M3") +model = prepare(model, qconfig) +# customer defined calibration +calib_func(model) +model = convert(model) +``` + +## Examples + +| Task | Example | +|----------------------|---------| +| Computer Vision (CV) | [Link](../../examples/3.x_api/pytorch/cv/fp8_quant/) | +| Large Language Model (LLM) | [Link](https://github.com/HabanaAI/optimum-habana-fork/tree/habana-main/examples/text-generation#running-with-fp8) | + +> Note: For LLM, Optimum-habana provides higher performance based on modified modeling files, so here the Link of LLM goes to Optimum-habana, which utilize Intel Neural Compressor for FP8 quantization internally. diff --git a/examples/.config/model_params_pytorch_3x.json b/examples/.config/model_params_pytorch_3x.json index e38749e2ef6..7b526005223 100644 --- a/examples/.config/model_params_pytorch_3x.json +++ b/examples/.config/model_params_pytorch_3x.json @@ -140,6 +140,13 @@ "main_script": "main.py", "batch_size": 1 }, + "resnet18_fp8_static":{ + "model_src_dir": "cv/fp8_quant", + "dataset_location": "/tf_dataset/pytorch/ImageNet/raw", + "input_model": "", + "main_script": "main.py", + "batch_size": 1 + }, "opt_125m_pt2e_static":{ "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/static_quant/pt2e", "dataset_location": "", diff --git a/examples/3.x_api/pytorch/cv/fp8_quant/README.md b/examples/3.x_api/pytorch/cv/fp8_quant/README.md new file mode 100644 index 00000000000..ebad25f9f05 --- /dev/null +++ b/examples/3.x_api/pytorch/cv/fp8_quant/README.md @@ -0,0 +1,28 @@ +# ImageNet FP8 Quantization + +This implements FP8 quantization of popular model architectures, such as ResNet on the ImageNet dataset, which is supported by Intel Gaudi2 AI Accelerator. + +## Requirements + +To try on Intel Gaudi2, docker image with Gaudi Software Stack is recommended, please refer to following script for environment setup. More details can be found in [Gaudi Guide](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#launch-docker-image-that-was-built). +```bash +# Run a container with an interactive shell +docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest +``` + +- Install requirements +- `pip install -r requirements.txt` +- Download the ImageNet dataset from http://www.image-net.org/ + - Then, move and extract the training and validation images to labeled subfolders, using [the following shell script](extract_ILSVRC.sh) + +## Quantizaiton + +To quant a model and validate accaracy, run `main.py` with the desired model architecture and the path to the ImageNet dataset: + +```bash +python main.py --pretrained -t -a resnet50 -b 30 /path/to/imagenet +``` +or +```bash +bash run_quant.sh --input_model=resnet50 --dataset_location=/path/to/imagenet +``` diff --git a/examples/3.x_api/pytorch/cv/fp8_quant/extract_ILSVRC.sh b/examples/3.x_api/pytorch/cv/fp8_quant/extract_ILSVRC.sh new file mode 100644 index 00000000000..3ec05e8f328 --- /dev/null +++ b/examples/3.x_api/pytorch/cv/fp8_quant/extract_ILSVRC.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# +# script to extract ImageNet dataset +# ILSVRC2012_img_train.tar (about 138 GB) +# ILSVRC2012_img_val.tar (about 6.3 GB) +# make sure ILSVRC2012_img_train.tar & ILSVRC2012_img_val.tar in your current directory +# +# Adapted from: +# https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md +# https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4 +# +# imagenet/train/ +# ├── n01440764 +# │ ├── n01440764_10026.JPEG +# │ ├── n01440764_10027.JPEG +# │ ├── ...... +# ├── ...... +# imagenet/val/ +# ├── n01440764 +# │ ├── ILSVRC2012_val_00000293.JPEG +# │ ├── ILSVRC2012_val_00002138.JPEG +# │ ├── ...... +# ├── ...... +# +# +# Make imagnet directory +# +mkdir imagenet +# +# Extract the training data: +# +# Create train directory; move .tar file; change directory +mkdir imagenet/train && mv ILSVRC2012_img_train.tar imagenet/train/ && cd imagenet/train +# Extract training set; remove compressed file +tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar +# +# At this stage imagenet/train will contain 1000 compressed .tar files, one for each category +# +# For each .tar file: +# 1. create directory with same name as .tar file +# 2. extract and copy contents of .tar file into directory +# 3. remove .tar file +find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done +# +# This results in a training directory like so: +# +# imagenet/train/ +# ├── n01440764 +# │ ├── n01440764_10026.JPEG +# │ ├── n01440764_10027.JPEG +# │ ├── ...... +# ├── ...... +# +# Change back to original directory +cd ../.. +# +# Extract the validation data and move images to subfolders: +# +# Create validation directory; move .tar file; change directory; extract validation .tar; remove compressed file +mkdir imagenet/val && mv ILSVRC2012_img_val.tar imagenet/val/ && cd imagenet/val && tar -xvf ILSVRC2012_img_val.tar && rm -f ILSVRC2012_img_val.tar +# get script from soumith and run; this script creates all class directories and moves images into corresponding directories +wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash +# +# This results in a validation directory like so: +# +# imagenet/val/ +# ├── n01440764 +# │ ├── ILSVRC2012_val_00000293.JPEG +# │ ├── ILSVRC2012_val_00002138.JPEG +# │ ├── ...... +# ├── ...... +# +# +# Check total files after extract +# +# $ find train/ -name "*.JPEG" | wc -l +# 1281167 +# $ find val/ -name "*.JPEG" | wc -l +# 50000 +# \ No newline at end of file diff --git a/examples/3.x_api/pytorch/cv/fp8_quant/main.py b/examples/3.x_api/pytorch/cv/fp8_quant/main.py new file mode 100644 index 00000000000..dfa7515343c --- /dev/null +++ b/examples/3.x_api/pytorch/cv/fp8_quant/main.py @@ -0,0 +1,391 @@ +import argparse +import os +import random +import shutil +import time +import warnings +import sys + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.distributed as dist +import torch.optim +import torch.multiprocessing as mp +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import torchvision.models as models +from neural_compressor.torch.quantization import ( + FP8Config, + prepare, + convert, +) +import habana_frameworks.torch.core as htcore + + +model_names = models.list_models(module=models) + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', + choices=model_names, + help='model architecture: ' + + ' | '.join(model_names) + + ' (default: resnet18)') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 4)') +parser.add_argument('--epochs', default=90, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, + metavar='LR', help='initial learning rate', dest='lr') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') +parser.add_argument('-p', '--print-freq', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', + help='evaluate model on validation set') +parser.add_argument('-t', '--tune', dest='tune', action='store_true', + help='tune best int8 model on calibration dataset') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--world-size', default=-1, type=int, + help='number of nodes for distributed training') +parser.add_argument('--rank', default=-1, type=int, + help='node rank for distributed training') +parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. ') +parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use.') +parser.add_argument('--ppn', default=1, type=int, + help='number of processes on each node of distributed training') +parser.add_argument('--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') +parser.add_argument("--calib_iters", default=10, type=int, + help="For calibration only.") +parser.add_argument('-i', "--iter", default=0, type=int, + help='For accuracy measurement only.') +parser.add_argument('-w', "--warmup_iter", default=5, type=int, + help='For benchmark measurement only.') +parser.add_argument('--performance', dest='performance', action='store_true', + help='run benchmark') +parser.add_argument('-r', "--accuracy", dest='accuracy', action='store_true', + help='For accuracy measurement only.') +parser.add_argument("--tuned_checkpoint", default='./saved_results', type=str, metavar='PATH', + help='path to checkpoint tuned by Neural Compressor (default: ./)') +parser.add_argument('--int8', dest='int8', action='store_true', + help='run benchmark') +parser.add_argument('--device', default='hpu', type=str, + help='use hpu device for fp8 quantization') + +best_acc1 = 0 + + +def main(): + args = parser.parse_args() + + if 'mobilenet' in args.arch: + import torchvision.models.quantization as models + else: + import torchvision.models as models + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + + if args.pretrained: + print("=> using pre-trained model '{}'".format(args.arch)) + model = models.__dict__[args.arch](pretrained=True) + else: + print("=> creating model '{}'".format(args.arch)) + model = models.__dict__[args.arch]() + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss() + + optimizer = torch.optim.SGD(model.parameters(), args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(args.resume) + args.start_epoch = checkpoint['epoch'] + best_acc1 = checkpoint['best_acc1'] + if args.gpu is not None: + # best_acc1 may be from a checkpoint from a different GPU + best_acc1 = best_acc1.to(args.gpu) + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=True, + num_workers=args.workers, pin_memory=True, sampler=None) + + val_dataset = datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])) + + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + + if args.evaluate: + validate(val_loader, model, criterion, args) + return + + def eval_func(model): + accu = validate(val_loader, model, criterion, args) + return float(accu) + + if args.tune: + qconfig = FP8Config(fp8_config="E4M3") + model = prepare(model, qconfig) + + # Calibrate + # model is moved to HPU device automatically after preparing + with torch.no_grad(): + for i, (images, target) in enumerate(train_loader): + print("Calibrating batch:", i) + if i == args.calib_iters: + break + images = images.to(args.device) + model(images) + htcore.mark_step() + + model = convert(model) + eval_func(model) + # The saving and loading of fp8 quantization are planned in the next release. + + if args.performance or args.accuracy: + model.eval() + if args.int8: + from neural_compressor.utils.pytorch import load + new_model = load(os.path.abspath(os.path.expanduser(args.tuned_checkpoint)), + model, + dataloader=val_loader) + else: + new_model = model + if args.performance: + from neural_compressor.config import BenchmarkConfig + from neural_compressor import benchmark + b_conf = BenchmarkConfig(warmup=5, + iteration=args.iter, + cores_per_instance=4, + num_of_instance=1) + benchmark.fit(new_model, b_conf, b_dataloader=val_loader) + if args.accuracy: + validate(val_loader, new_model, criterion, args) + return + + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1, + top5, prefix="Epoch: [{}]".format(epoch)) + + # switch to train mode + model.train() + + end = time.time() + for i, (input, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + if args.gpu is not None: + input = input.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(acc1[0], input.size(0)) + top5.update(acc5[0], input.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.print(i) + + +def validate(val_loader, model, criterion, args): + batch_time = AverageMeter('Time', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, + prefix='Test: ') + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + for i, (input, target) in enumerate(val_loader): + if i >= args.warmup_iter: + start = time.time() + input = input.to(args.device) + target = target.to(args.device) + if args.gpu is not None: + input = input.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(acc1[0], input.size(0)) + top5.update(acc5[0], input.size(0)) + + # measure elapsed time + if i >= args.warmup_iter: + batch_time.update(time.time() - start) + + if i % args.print_freq == 0: + progress.print(i) + + if args.iter > 0 and i >= (args.warmup_iter + args.iter - 1): + break + + print('Batch size = %d' % args.batch_size) + print('Accuracy: {top1:.5f} Accuracy@5 {top5:.5f}' + .format(top1=(top1.avg / 100), top5=(top5.avg / 100))) + + return top1.avg/100 + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, *meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def print(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + + +def adjust_learning_rate(optimizer, epoch, args): + """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" + lr = args.lr * (0.1 ** (epoch // 30)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == '__main__': + main() diff --git a/examples/3.x_api/pytorch/cv/fp8_quant/requirements.txt b/examples/3.x_api/pytorch/cv/fp8_quant/requirements.txt new file mode 100644 index 00000000000..ebd3df6ae7a --- /dev/null +++ b/examples/3.x_api/pytorch/cv/fp8_quant/requirements.txt @@ -0,0 +1,3 @@ +torch +torchvision +neural-compressor \ No newline at end of file diff --git a/examples/3.x_api/pytorch/cv/fp8_quant/run_quant.sh b/examples/3.x_api/pytorch/cv/fp8_quant/run_quant.sh new file mode 100644 index 00000000000..4d0047cf2d1 --- /dev/null +++ b/examples/3.x_api/pytorch/cv/fp8_quant/run_quant.sh @@ -0,0 +1,53 @@ +#!/bin/bash +set -x + +function main { + + init_params "$@" + run_tuning + +} + +# init params +function init_params { + output_model=saved_results + for var in "$@" + do + case $var in + --topology=*) + topology=$(echo $var |cut -f2 -d=) + ;; + --dataset_location=*) + dataset_location=$(echo $var |cut -f2 -d=) + ;; + --input_model=*) + input_model=$(echo $var |cut -f2 -d=) + ;; + --output_model=*) + output_model=$(echo $var |cut -f2 -d=) + ;; + *) + echo "Error: No such parameter: ${var}" + exit 1 + ;; + esac + done + +} + +# run_tuning +function run_tuning { + if [ "${topology}" = "resnet18_fp8_static" ]; then + input_model="resnet18" + output_dir="saved_results" + fi + python main.py \ + --pretrained \ + -t \ + -a ${input_model} \ + -b 30 \ + --tuned_checkpoint ${output_model} \ + ${dataset_location} +} + +main "$@" diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py index 09032360cc0..d54e2e6515b 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -379,8 +379,10 @@ def to_json_file(self, filename): Args: filename (str): The path to save the JSON file. """ - # Implementation details omitted for brevity - pass + config_dict = self.to_dict() + with open(filename, "w", encoding="utf-8") as file: + json.dump(config_dict, file, indent=4) + logger.info("Dump the config into %s.", filename) def to_json_string(self, use_diff: bool = False) -> str: """Serializes this instance to a JSON string. diff --git a/neural_compressor/torch/algorithms/fp8_quant/common.py b/neural_compressor/torch/algorithms/fp8_quant/common.py index 9bce5b39a37..163509a6048 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/common.py +++ b/neural_compressor/torch/algorithms/fp8_quant/common.py @@ -91,8 +91,8 @@ def restore_patched_module(patched_model): class_name_org = ( getattr(patched_mod, "class_name_org", None) or patched_mod.__class__.__name__.split("Patched")[-1] ) + patched_mod.__dict__.pop("forward", None) origin_mod = helper_mods[class_name_org](patched_mod) - origin_mod.forward = patched_mod.forward_orig setattr(parent, name, origin_mod) diff --git a/test/3x/torch/quantization/fp8_quant/test_fp8_static_quant.py b/test/3x/torch/quantization/fp8_quant/test_fp8_static_quant.py index eb71a550782..88735312713 100644 --- a/test/3x/torch/quantization/fp8_quant/test_fp8_static_quant.py +++ b/test/3x/torch/quantization/fp8_quant/test_fp8_static_quant.py @@ -3,9 +3,10 @@ import pytest import torch +import torchvision import transformers -from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import PatchedLinear +from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import PatchedLinear, PatchedConv2d from neural_compressor.torch.quantization import ( FP8Config, convert, @@ -31,24 +32,57 @@ def setup_class(self): "hf-internal-testing/tiny-random-GPTJForCausalLM", device_map="cpu", ) - self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long) + self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long).to("hpu") + self.resnet18 = torchvision.models.resnet18(pretrained=True) + self.cv_dummy_inputs = torch.randn([1,3,224,224]).to("hpu") def teardown_class(self): shutil.rmtree("test_ouputs", ignore_errors=True) - def test_one_step_quant(self): + def test_one_step_quant_nlp(self): model = copy.deepcopy(self.tiny_gptj) + model.to('hpu') + fp32_out = model(self.example_inputs)[0] qconfig = FP8Config(fp8_config="E4M3") model = prepare(model, qconfig) assert isinstance(model.transformer.h[0].attn.k_proj, PatchedLinear), "k_proj is not prepared." calib_func(model) model = convert(model) + fp8_out = model(self.example_inputs)[0] assert isinstance(model.transformer.h[0].attn.k_proj, PatchedLinear), "k_proj is not quantized." assert ( model.transformer.h[0].attn.k_proj.quant_input.lp_dtype == torch.float8_e4m3fn ), "k_proj input dtype is not torch.float8_e4m3fn." + assert (fp32_out != fp8_out).any(), "FP32 output should be different with FP8 output" + print((fp32_out - fp8_out).abs().max()) + assert torch.allclose(fp32_out, fp8_out, atol=0.04), "Accuracy gap atol > 0.04 is unexpected." - def test_two_step_quant(self): + # @pytest.mark.skipif(not is_hpex_available(), reason="HPU environment is required!") + def test_one_step_quant_cv(self): + model = copy.deepcopy(self.resnet18) + model.to('hpu') + fp32_out = model(self.cv_dummy_inputs) + # model.to('cpu') + qconfig = FP8Config(fp8_config="E4M3") + model = prepare(model, qconfig) + assert model.fc.weight.device.type == "hpu", "model is not mapped to HPU." + assert ( + isinstance(model.fc, PatchedLinear) and + isinstance(model.conv1, PatchedConv2d) + ), "model is not prepared." + # calibration + model(self.cv_dummy_inputs) + model = convert(model) + fp8_out = model(self.cv_dummy_inputs) + assert ( + isinstance(model.fc, PatchedLinear) and + isinstance(model.conv1, PatchedConv2d) and + model.fc.quant_input.lp_dtype == torch.float8_e4m3fn and + model.conv1.quant_input.lp_dtype == torch.float8_e4m3fn + ), "model is not quantized to torch.float8_e4m3fn." + assert (fp32_out != fp8_out).any(), "FP32 output should be different with FP8 output" + + def test_two_step_quant_nlp(self): # step 1: measurement model = copy.deepcopy(self.tiny_gptj) config = FP8Config.from_json_file("test_fp8_jsons/test_measure.json") @@ -64,3 +98,27 @@ def test_two_step_quant(self): assert ( model.transformer.h[0].attn.k_proj.quant_input.lp_dtype == torch.float8_e4m3fn ), "k_proj input dtype is not torch.float8_e4m3fn." + + def test_two_step_quant_cv(self): + # step 1: measurement + model = copy.deepcopy(self.resnet18) + config = FP8Config.from_json_file("test_fp8_jsons/test_measure.json") + model = prepare(model, config) + fp32_out = model(self.cv_dummy_inputs) + finalize_calibration(model) + assert ( + isinstance(model.fc, PatchedLinear) and + isinstance(model.conv1, PatchedConv2d) + ), "model is not prepared." + # step 2: quantize based on measurement + model = copy.deepcopy(self.resnet18) + config = FP8Config.from_json_file("test_fp8_jsons/test_hw_quant.json") + model = convert(model, config) + fp8_out = model(self.cv_dummy_inputs) + assert ( + isinstance(model.fc, PatchedLinear) and + isinstance(model.conv1, PatchedConv2d) and + model.fc.quant_input.lp_dtype == torch.float8_e4m3fn and + model.conv1.quant_input.lp_dtype == torch.float8_e4m3fn + ), "model is not quantized to torch.float8_e4m3fn." + assert (fp32_out != fp8_out).any(), "FP32 output should be different with FP8 output" diff --git a/test/3x/torch/requirements.txt b/test/3x/torch/requirements.txt index c17e22d6f77..6605132ff6f 100644 --- a/test/3x/torch/requirements.txt +++ b/test/3x/torch/requirements.txt @@ -6,3 +6,4 @@ prettytable psutil pytest transformers +torchvision