[NNCF]: Add INT8 weight compression conformance test for Tinyllama-1.1b PyTorch model #2636

Merged
41 commits merged on May 2, 2024

Commits (41)
2d6fbe5
feat: Added to the test scope
AdiKsOnDev Apr 9, 2024
52e8180
feat: Added torch backend support
AdiKsOnDev Apr 12, 2024
fd48363
Merge branch 'openvinotoolkit:develop' into develop
AdiKsOnDev Apr 16, 2024
c024803
fix: Moved int8 conversion in _validate()
AdiKsOnDev Apr 18, 2024
6405c4e
git: Merge branch 'develop' of github.com:AdiKsOnDev/nncf into develop
AdiKsOnDev Apr 18, 2024
f48c148
fix: Returned initial implementation of _validate()
AdiKsOnDev Apr 22, 2024
f9505e4
chore: Temporary dummy data
AdiKsOnDev Apr 22, 2024
2bc73ec
fix: Model Preparation for TORCH backend
AdiKsOnDev Apr 22, 2024
927c38f
fix: Removed unsupported parameters for INT8
AdiKsOnDev Apr 22, 2024
f008103
chore: Comment on important addition
AdiKsOnDev Apr 22, 2024
eeade47
feat: Added correct metric value according to @aleksu52
AdiKsOnDev Apr 22, 2024
fc05eed
fix: More accurate check for the INT8 compression mode
AdiKsOnDev Apr 23, 2024
4aefa0d
feat: Problematic code for @aleksu52 to reproduce
AdiKsOnDev Apr 23, 2024
737c1a7
feat: Use AutoModelForCausalLM for TORCH models
AdiKsOnDev Apr 24, 2024
8066b76
fix: Added model specific parameters during preparation
AdiKsOnDev Apr 24, 2024
512aa63
Merge branch 'openvinotoolkit:develop' into develop
AdiKsOnDev Apr 24, 2024
0041998
refactor: Make a tokenizer during model preparation
AdiKsOnDev Apr 24, 2024
3a61ccf
feat: Tokenize an input string (Temporary) to feed in torch model
AdiKsOnDev Apr 24, 2024
ea0c4c4
fix: Added torch_dtype parameter to the model
AdiKsOnDev Apr 24, 2024
c346100
chore: Removed unnecessary compression parameters
AdiKsOnDev Apr 24, 2024
1cfccf9
refactor: Line spacing, preprocessor usage
AdiKsOnDev Apr 25, 2024
88dc901
Merge branch 'openvinotoolkit:develop' into develop
AdiKsOnDev Apr 26, 2024
5deba30
fix: Removing convert_model()
AdiKsOnDev Apr 27, 2024
40c5686
fix: The pipeline now runs for TORCH models
AdiKsOnDev Apr 27, 2024
d3989be
fix: Using model_hf for validation
AdiKsOnDev Apr 28, 2024
43aec31
fix: Changed the reference metric value
AdiKsOnDev Apr 28, 2024
a85ded2
refactor: Pre-Commit changes
AdiKsOnDev Apr 28, 2024
28af569
fix: Returned the original checks for int4/int8 values
AdiKsOnDev Apr 30, 2024
c3d5e2d
Merge branch 'openvinotoolkit:develop' into develop
AdiKsOnDev Apr 30, 2024
a72ae7e
chore: Pre-Commit changes
AdiKsOnDev Apr 30, 2024
2f6f69c
git: Merge main branch
AdiKsOnDev Apr 30, 2024
d90b356
Merge branch 'develop' into develop
AdiKsOnDev Apr 30, 2024
7d328c3
refactor: Pre-Commit Changes
AdiKsOnDev Apr 30, 2024
7e50cfa
fix: Removed the debugging line
AdiKsOnDev May 1, 2024
7c31d3d
fix: Corrected reference data for TORCH backend
AdiKsOnDev May 2, 2024
6899097
refactor: Code made cleaner
AdiKsOnDev May 2, 2024
86e91f9
fix: Utilized wikitext for TORCH models as well
AdiKsOnDev May 2, 2024
7f32430
feat: Implemented get_num_compressed
AdiKsOnDev May 2, 2024
7729867
fix: Dumping the fp32 model correctly
AdiKsOnDev May 2, 2024
70cd912
chore: Removed unnecessary model wrapping
AdiKsOnDev May 2, 2024
e5db8cc
fix: Changed _validate to match the modified pipeline
AdiKsOnDev May 2, 2024
tests/post_training/data/wc_reference_data.yaml (4 changes: 4 additions, 0 deletions)
@@ -18,3 +18,7 @@ tinyllama_data_aware_awq_scale_estimation_stateful_backend_OV:
   metric_value: 0.83795
   num_int4: 188
   num_int8: 124
+tinyllama_int8_data_free_backend_TORCH:
+  metric_value: 0.95624
+  num_int4: 0
+  num_int8: 312
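For context, the new tinyllama_int8_data_free_backend_TORCH entry records the expected metric value (a whowhatbench similarity score against the FP32 model) and the expected counts of INT4/INT8 weight ops for the TORCH run; since this is INT8-only compression, num_int4 stays 0. Below is a minimal sketch of how a run's results could be checked against such reference data; the load path, key construction, and tolerance are illustrative assumptions, not the suite's actual logic:

import yaml

def check_against_reference(result: dict, ref_path: str = "wc_reference_data.yaml") -> None:
    # Look up the expected values for this test case (reported_name plus backend suffix).
    with open(ref_path) as f:
        refs = yaml.safe_load(f)
    ref = refs[f"{result['reported_name']}_backend_{result['backend']}"]

    # The similarity metric may drift slightly between runs; op counts must match exactly.
    assert abs(result["metric_value"] - ref["metric_value"]) < 1e-3
    assert result["num_int4"] == ref["num_int4"]
    assert result["num_int8"] == ref["num_int8"]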
tests/post_training/model_scope.py (9 changes: 9 additions, 0 deletions)
@@ -370,6 +370,15 @@
         "params": {"is_stateful": True},
         "backends": [BackendType.OV],
     },
+    {
+        "reported_name": "tinyllama_int8_data_free",
+        "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b",
+        "pipeline_cls": LMWeightCompression,
+        "compression_params": {
+            "mode": CompressWeightsMode.INT8_ASYM,
+        },
+        "backends": [BackendType.TORCH],
+    },
 ]
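This scope entry runs the existing LMWeightCompression pipeline on the PyTorch TinyLlama checkpoint with data-free INT8 asymmetric weight compression. A minimal standalone sketch of the operation the test case exercises, under the assumption that a single tokenized sample is enough for NNCF to build the PyTorch model graph (the actual pipeline feeds wikitext samples through its calibration dataset):

import nncf
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tinyllama/tinyllama-1.1b-step-50k-105b"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map="cpu")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# One tokenized sample serves as the example input for graph building on the TORCH backend;
# INT8_ASYM itself is data-free and collects no activation statistics.
sample = dict(tokenizer("example input for tracing", return_tensors="pt"))
dataset = nncf.Dataset([sample])

compressed = nncf.compress_weights(model, mode=nncf.CompressWeightsMode.INT8_ASYM, dataset=dataset)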


tests/post_training/pipelines/lm_weight_compression.py (111 changes: 73 additions, 38 deletions)
@@ -18,9 +18,12 @@

import numpy as np
import openvino as ov
import torch
from datasets import load_dataset
from memory_profiler import memory_usage
from optimum.exporters.openvino.convert import export_from_model
from optimum.intel.openvino import OVModelForCausalLM
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from whowhatbench import Evaluator

@@ -72,20 +75,36 @@ class LMWeightCompression(BaseTestPipeline):

def prepare_model(self) -> None:
is_stateful = self.params.get("is_stateful", False)
if is_stateful:
self.fp32_model_dir = self.fp32_model_dir.parent / (self.fp32_model_dir.name + "_sf")
if not (self.fp32_model_dir / self.OV_MODEL_NAME).exists():
# export by model_id
self.model_hf = OVModelForCausalLM.from_pretrained(
self.model_id, export=True, load_in_8bit=False, compile=False, stateful=is_stateful

# load model
if self.backend == BackendType.TORCH:
if is_stateful:
raise RuntimeError(f"is_stateful={is_stateful} is not supported for PyTorch backend.")

self.model_hf = AutoModelForCausalLM.from_pretrained(
self.model_id, torch_dtype=torch.float32, device_map="cpu"
)
self._dump_model_fp32()
self.model = self.model_hf
elif self.backend == BackendType.OV:
if is_stateful:
self.fp32_model_dir = self.fp32_model_dir.parent / (self.fp32_model_dir.name + "_sf")
if not (self.fp32_model_dir / self.OV_MODEL_NAME).exists():
# export by model_id
self.model_hf = OVModelForCausalLM.from_pretrained(
self.model_id, export=True, load_in_8bit=False, compile=False, stateful=is_stateful
)
else:
# no export, load from IR. Applicable for sequential run of test cases in local environment.
self.model_hf = OVModelForCausalLM.from_pretrained(
self.fp32_model_dir, trust_remote_code=True, load_in_8bit=False, compile=False, stateful=is_stateful
)
self.model = self.model_hf.model
else:
# no export, load from IR. Applicable for sequential run of test cases in local environment.
self.model_hf = OVModelForCausalLM.from_pretrained(
self.fp32_model_dir, trust_remote_code=True, load_in_8bit=False, compile=False, stateful=is_stateful
)
self.model = self.model_hf.model
raise RuntimeError(f"backend={self.backend.value} is not supported.")

# dump FP32 model
if not (self.fp32_model_dir / self.OV_MODEL_NAME).exists():
self._dump_model_fp32()

def prepare_preprocessor(self) -> None:
self.preprocessor = AutoTokenizer.from_pretrained(self.model_id)
@@ -108,36 +127,40 @@ def transform_fn(data, max_tokens=128):
inputs["attention_mask"] = attention_mask
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1

# The magic forms KV cache as model inputs
batch_size = input_ids.shape[0]
for input_name in self.model_hf.key_value_input_names:
model_inputs = self.model.input(input_name)
shape = model_inputs.get_partial_shape()
shape[0] = batch_size
if shape[2].is_dynamic:
shape[2] = 0
else:
shape[1] = 0
inputs[input_name] = ov.Tensor(model_inputs.get_element_type(), shape.get_shape())

inputs["position_ids"] = position_ids

# initialize the rest of inputs (e.g. beam_idx for stateful models)
for val in self.model.inputs:
name = val.any_name
if name in inputs:
continue
shape = list(val.partial_shape.get_min_shape())
shape[0] = batch_size
inputs[name] = np.zeros(shape)
if self.backend == BackendType.OV:
# The magic forms KV cache as model inputs
batch_size = input_ids.shape[0]
for input_name in self.model_hf.key_value_input_names:
model_inputs = self.model.input(input_name)
shape = model_inputs.get_partial_shape()
shape[0] = batch_size
if shape[2].is_dynamic:
shape[2] = 0
else:
shape[1] = 0
inputs[input_name] = ov.Tensor(model_inputs.get_element_type(), shape.get_shape())

# initialize the rest of inputs (e.g. beam_idx for stateful models)
for val in self.model.inputs:
name = val.any_name
if name in inputs:
continue
shape = list(val.partial_shape.get_min_shape())
shape[0] = batch_size
inputs[name] = np.zeros(shape)
if self.backend == BackendType.TORCH:
for input_name in inputs:
inputs[input_name] = torch.from_numpy(inputs[input_name])
return inputs

return transform_fn

def prepare_calibration_dataset(self):
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train", revision="b08601e")
dataset = dataset.filter(lambda example: len(example["text"]) > 128)

self.calibration_dataset = nncf.Dataset(dataset, self.get_transform_calibration_fn())

def cleanup_cache(self):
Expand All @@ -164,8 +187,12 @@ def collect_data_from_stdout(self, stdout: str):
def save_compressed_model(self) -> None:
if self.backend == BackendType.FP32:
return
ov.serialize(self.model, self.output_model_dir / self.OV_MODEL_NAME)
self.model_hf._save_config(self.output_model_dir)

if self.backend == BackendType.OV:
ov.serialize(self.model, self.output_model_dir / self.OV_MODEL_NAME)
self.model_hf._save_config(self.output_model_dir)
elif self.backend == BackendType.TORCH:
export_from_model(self.model_hf, self.output_model_dir, stateful=False, compression_option="fp32")

def get_num_compressed(self) -> None:
"""
@@ -174,7 +201,12 @@
num_int8 = 0
num_int4 = 0

for node in self.model.get_ops():
if self.backend == BackendType.TORCH:
model = ov.Core().read_model(self.output_model_dir / self.OV_MODEL_NAME)
else:
model = self.model

for node in model.get_ops():
for i in range(node.get_output_size()):
if node.get_output_element_type(i).get_type_name() in ["i8", "u8"]:
num_int8 += 1
@@ -192,8 +224,11 @@ def _dump_model_fp32(self) -> None:
Dump IRs of fp32 models, to help debugging. The test cases may share the same fp32 model, therefore it is saved
to the dedicated shared folder.
"""
self.model_hf.save_pretrained(self.fp32_model_dir)
self.model_hf._save_config(self.fp32_model_dir)
if self.backend == BackendType.OV:
self.model_hf.save_pretrained(self.fp32_model_dir)
self.model_hf._save_config(self.fp32_model_dir)
elif self.backend == BackendType.TORCH:
export_from_model(self.model_hf, self.fp32_model_dir, stateful=False, compression_option="fp32")

def _compress(self):
"""