Add docstring for PT2E and HQQ (#1937)
Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored Jul 19, 2024
1 parent 437c8e7 commit 296c5d4
Showing 17 changed files with 454 additions and 114 deletions.
4 changes: 4 additions & 0 deletions .azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt
@@ -15,3 +15,7 @@
/neural-compressor/neural_compressor/strategy
/neural-compressor/neural_compressor/training.py
/neural-compressor/neural_compressor/utils
/neural_compressor/torch/algorithms/pt2e_quant
/neural_compressor/torch/export
/neural_compressor/common
/neural_compressor/torch/algorithms/weight_only/hqq
1 change: 1 addition & 0 deletions neural_compressor/torch/algorithms/pt2e_quant/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The PT2E-related modules."""


from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
39 changes: 36 additions & 3 deletions neural_compressor/torch/algorithms/pt2e_quant/core.py
@@ -14,7 +14,7 @@

# Some code snippets are taken from the X86InductorQuantizer tutorial.
# https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html

"""The quantizer using PT2E path."""

from typing import Any

@@ -30,13 +30,24 @@


class W8A8PT2EQuantizer(Quantizer):
    """The W8A8 quantizer using PT2E."""

    is_dynamic = False

    def __init__(self, quant_config=None):
        """Initialize the quantizer."""
        super().__init__(quant_config)

    @staticmethod
    def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuantizer:
        """Updates the quantizer based on the given quantization configuration.

        Args:
            quant_config (dict): The quantization configuration. Defaults to None.

        Returns:
            X86InductorQuantizer: The updated quantizer object.
        """
        if not quant_config:
            quantizer = X86InductorQuantizer()
            quantizer.set_global(
@@ -47,9 +58,18 @@ def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuan
        return quantizer

    def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args, **kwargs) -> GraphModule:
        """Prepares the model for calibration.

        Create the `quantizer` according to the `quant_config`, and insert the observers accordingly.

        Args:
            model (GraphModule): The model to be prepared for calibration.
            example_inputs (tuple, optional): Example inputs to be used for calibration. Defaults to None.
            inplace (bool, optional): Whether to modify the model in-place or return a new prepared model.
                Defaults to True.

        Returns:
            GraphModule: The prepared model.
        """
        quant_config = self.quant_config
        assert model._exported, "The model should be exported before preparing it for calibration."
@@ -58,7 +78,14 @@ def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args,
        return prepared_model

    def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
        """Convert the calibrated model into qdq mode.

        Args:
            model (GraphModule): The prepared model.

        Returns:
            GraphModule: The converted quantized model.
        """
        fold_quantize = kwargs.get("fold_quantize", False)
        converted_model = convert_pt2e(model, fold_quantize=fold_quantize)
        logger.warning("Converted the model in qdq mode, please compile it to accelerate inference.")
@@ -67,6 +94,12 @@ def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
        return converted_model

    def half_precision_transformation(self, model, config):
        """Applies half-precision transformation to the given model in-place.

        Args:
            model: The model to apply the transformation to.
            config: The configuration for the transformation.
        """
        half_precision_node_set = hp_rewriter.get_half_precision_node_set(model, config)
        logger.info("Try to convert %d nodes to half precision.", len(half_precision_node_set))
        hp_rewriter.transformation(model, half_precision_node_set)
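
Taken together, these methods define the PT2E flow: export the model, prepare it (insert observers), run calibration, then convert to q/dq mode. A minimal sketch, assuming a toy model; the `_exported` flag is normally set by the export helper in `neural_compressor.torch.export` and is set manually here only for illustration:

```python
# Hypothetical end-to-end sketch (not part of this diff); assumes torch >= 2.1.
import torch
from neural_compressor.torch.algorithms.pt2e_quant import W8A8PT2EQuantizer

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
example_inputs = (torch.randn(2, 8),)

# Export to an FX GraphModule; `prepare` asserts that `model._exported` is set.
exported = torch.export.export(model, example_inputs).module()
exported._exported = True  # assumption: mimics the repo's export helper

quantizer = W8A8PT2EQuantizer()
prepared = quantizer.prepare(exported, example_inputs=example_inputs)
prepared(*example_inputs)  # calibration pass over sample data
converted = quantizer.convert(prepared)  # q/dq model; compile it for speed
```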
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rewrite the FP32 operators to FP16 or BF16 operators."""

from dataclasses import dataclass
from functools import partial
@@ -34,6 +35,14 @@

@dataclass
class PatternPair:
    """Represents a pair of patterns used for search and replacement in a graph.

    Attributes:
        fn (TorchFuncType): The function type associated with the pattern pair.
        search_pattern (torch.fx.GraphModule): The search pattern to be matched in the graph.
        replace_pattern (torch.fx.GraphModule): The replacement pattern to be used when a match is found.
    """

    fn: TorchFuncType
    search_pattern: torch.fx.GraphModule
    replace_pattern: torch.fx.GraphModule
@@ -101,6 +110,15 @@ def _register_pattern_pair(dtype: torch.dtype) -> None:


def get_filter_fn(node_list, fn):
    """Filter function to check if a node with the target operator is in the given `node_list`.

    Args:
        node_list (list): List of nodes to check against.
        fn (str): Target operator.

    Returns:
        bool: True if the node with the target operator is in the `node_list`, False otherwise.
    """
    target_op = FN_ATEN_OPS_MAPPING[fn]

    def is_target_node_in_candidate_list(match, original_graph, pattern_graph):
@@ -119,6 +137,16 @@ def is_target_node_in_candidate_list(match, original_graph, pattern_graph):


def apply_single_pattern_pair(gm: torch.fx.GraphModule, pattern_pair: PatternPair, node_list):
    """Applies a single pattern pair to a given GraphModule.

    Args:
        gm (torch.fx.GraphModule): The GraphModule to apply the pattern pair to.
        pattern_pair (PatternPair): The pattern pair containing the search and replace patterns.
        node_list: The list of nodes to filter for pattern matching.

    Returns:
        List[Match]: A list of Match objects representing the matches found after applying the pattern pair.
    """
    filter_fn = get_filter_fn(node_list, pattern_pair.fn)
    match_and_replacements = subgraph_rewriter.replace_pattern_with_filters(
        gm=gm,
@@ -133,6 +161,14 @@ def apply_single_pattern_pair(gm: torch.fx.GraphModule, pattern_pair: PatternPai


def get_unquantized_node_set(gm: torch.fx.GraphModule):
    """Retrieves the set of unquantized nodes from a given GraphModule.

    Args:
        gm (torch.fx.GraphModule): The GraphModule to retrieve unquantized nodes from.

    Returns:
        set: A set containing the unquantized nodes.
    """
    unquantized_node_set = set()
    for node in gm.graph.nodes:
        if meta := getattr(node, "meta"):
@@ -180,7 +216,17 @@ def _parse_node_candidate_set_from_user_config(config, gm):


def get_half_precision_node_set(gm, config):
    """Retrieves a set of nodes from the given graph module (gm) that are candidates for conversion to half precision.

    The result is the intersection between `unquantized_node_set` and `node_set_from_user_config`.

    Args:
        gm (GraphModule): The graph module to search for nodes.
        config (dict): User configuration for node candidate set.

    Returns:
        set: A set of nodes that are candidates for conversion to half precision.
    """
    # TODO: implement it, current return all unquantized_node_set

    node_set_from_user_config = _parse_node_candidate_set_from_user_config(config, gm)
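
The rewriter drives `torch.fx.subgraph_rewriter.replace_pattern_with_filters` with the registered `PatternPair`s. For intuition, here is a toy sketch of the same FX search-and-replace mechanism — not the repo's FP16/BF16 rewrite, just an arbitrary pattern swap:

```python
# Toy illustration of FX subgraph rewriting (not the repo's half-precision rewrite).
import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0

def search_pattern(x):
    return torch.relu(x) + 1.0

def replace_pattern(x):
    # Arbitrary replacement purely for demonstration.
    return torch.relu(x + 1.0)

gm = symbolic_trace(M())
matches = subgraph_rewriter.replace_pattern(gm, search_pattern, replace_pattern)
print(f"{len(matches)} subgraph(s) rewritten")
```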
17 changes: 17 additions & 0 deletions neural_compressor/torch/algorithms/pt2e_quant/save_load.py
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Save and load the quantized model."""


import json
import os
@@ -22,6 +24,13 @@


def save(model, example_inputs, output_dir="./saved_results"):
    """Save the quantized model and its configuration.

    Args:
        model (torch.nn.Module): The quantized model to be saved.
        example_inputs (torch.Tensor or tuple of torch.Tensor): Example inputs used for tracing the model.
        output_dir (str, optional): The directory where the saved results will be stored. Defaults to "./saved_results".
    """
    os.makedirs(output_dir, exist_ok=True)
    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
    qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
@@ -37,6 +46,14 @@ def save(model, example_inputs, output_dir="./saved_results"):


def load(output_dir="./saved_results"):
    """Load a quantized model from the specified output directory.

    Args:
        output_dir (str): The directory where the quantized model is saved. Defaults to "./saved_results".

    Returns:
        torch.nn.Module: The loaded quantized model.
    """
    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
    loaded_quantized_ep = torch.export.load(qmodel_file_path)
    return loaded_quantized_ep.module()
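
A minimal round-trip sketch, assuming `converted` and `example_inputs` from the quantization flow shown earlier:

```python
# Hypothetical round-trip (not part of this diff); paths follow the defaults above.
from neural_compressor.torch.algorithms.pt2e_quant.save_load import load, save

save(converted, example_inputs, output_dir="./saved_results")
restored = load("./saved_results")
out = restored(*example_inputs)
```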
22 changes: 22 additions & 0 deletions neural_compressor/torch/algorithms/pt2e_quant/utility.py
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for PT2E quantization."""

from typing import Dict

@@ -24,6 +25,18 @@


def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
    """Create a quantization specification based on the given configuration.

    Args:
        dtype (str): The desired data type for quantization. Valid options are "int8" and "uint8".
        sym (bool): Whether to use symmetric quantization or not.
        granularity (str): The granularity of quantization. Valid options are "per_channel" and "per_tensor".
        algo (str): The algorithm to use for quantization. Valid options are "placeholder", "minmax", and "kl".
        is_dynamic (bool, optional): Whether to use dynamic quantization or not. Defaults to False.

    Returns:
        QuantizationSpec: The created quantization specification.
    """
    dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8}
    select_dtype = dtype_mapping[dtype]
    min_max_mapping = {torch.int8: (-128, 127), torch.uint8: (0, 255)}
@@ -76,6 +89,15 @@ def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> Quant


def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86InductorQuantizer:
    """Creates an instance of X86InductorQuantizer based on the given configuration.

    Args:
        config: The configuration object containing the quantization settings.
        is_dynamic: A boolean indicating whether dynamic quantization is enabled.

    Returns:
        An instance of X86InductorQuantizer initialized with the provided configuration.
    """
    quantizer = xiq.X86InductorQuantizer()
    # set global
    global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
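
A short usage sketch of `create_quant_spec_from_config`, with option values taken directly from its docstring above:

```python
# Sketch: build activation/weight specs from the documented options.
from neural_compressor.torch.algorithms.pt2e_quant.utility import create_quant_spec_from_config

act_spec = create_quant_spec_from_config(
    dtype="uint8", sym=False, granularity="per_tensor", algo="minmax"
)
weight_spec = create_quant_spec_from_config(
    dtype="int8", sym=True, granularity="per_channel", algo="minmax"
)
```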
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HQQ-related modules."""

from .quantizer import HQQuantizer
from .config import HQQModuleConfig, QTensorConfig
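
The re-exported HQQ entry points can then be imported directly; their constructor arguments are not shown in this diff, so none are assumed here:

```python
# Only the names re-exported above are assumed.
from neural_compressor.torch.algorithms.weight_only.hqq import (
    HQQModuleConfig,
    HQQuantizer,
    QTensorConfig,
)
```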