
add int8 quantization support #3058

Merged (13 commits) on Aug 28, 2024
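This PR adds INT8 quantization support to the Torch-TensorRT dynamo path: a converter for torch.ops.trt.quantize_int8, a shared quantize() helper in impl/quantize.py that emits a TensorRT Q/DQ pair for either INT8 or FP8 (replacing the FP8-only quantize_fp8 body), and an end-to-end INT8 test. As orientation, the user-facing flow looks roughly like the sketch below, which mirrors the new test added in this PR (small Linear/ReLU model and illustrative settings); treat it as a minimal sketch, not the only supported configuration.

import torch
import torch_tensorrt as torchtrt
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode

# Tiny illustrative model and input, matching the shapes used in the new test
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 1)
).eval().cuda()
input_tensor = torch.randn(1, 10).cuda()

# Insert INT8 Q/DQ nodes with ModelOpt, calibrating with a simple forward pass
mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop=lambda m: m(input_tensor))

with torch.no_grad(), export_torch_mode():
    exp_program = torch.export.export(model, (input_tensor,))
    trt_model = torchtrt.dynamo.compile(
        exp_program,
        inputs=[input_tensor],
        enabled_precisions={torch.int8},
        min_block_size=1,
    )
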
18 changes: 18 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -605,13 +605,31 @@ def aten_ops_neg(
 try:
     import modelopt.torch.quantization as mtq  # noqa: F401
 
+    assert torch.ops.trt.quantize_int8.default
     assert torch.ops.trt.quantize_fp8.default
 except Exception as e:
     _LOGGER.warning(
"Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models"
     )
 else:
+
+    @dynamo_tensorrt_converter(torch.ops.trt.quantize_int8.default)
+    def aten_ops_quantize_int8(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+        return impl.quantize.quantize_int8(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+        )
 
     @dynamo_tensorrt_converter(torch.ops.trt.quantize_fp8.default)
     def aten_ops_quantize_fp8(
         ctx: ConversionContext,
44 changes: 38 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/impl/quantize.py
@@ -10,11 +10,12 @@
 from torch_tensorrt.fx.types import TRTTensor
 
 
-def quantize_fp8(
+def quantize(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
+    quantize_type: str,
     input_tensor: TRTTensor,
     scale: np.ndarray,
 ) -> TRTTensor:
@@ -26,20 +27,51 @@ def quantize_fp8(
         input_tensor.dtype == trt.float32 or input_tensor.dtype == trt.float16
     ):
         raise ValueError(
-            f"quantize_fp8 converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
+            f"quantize {quantize_type} converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
         )
+    if quantize_type not in ["fp8", "int8"]:
+        raise ValueError(
+            f"{quantize_type=} is not supported. Supported types: fp8 | int8"
+        )
 
     scale = get_trt_tensor(ctx, scale, name + "_scale")
     # Add Q node
     quantize_layer = ctx.net.add_quantize(input_tensor, scale)
-    quantize_layer.set_output_type(0, trt.DataType.FP8)
+    if quantize_type == "int8":
+        quantize_layer.set_output_type(0, trt.DataType.INT8)
+    else:
+        quantize_layer.set_output_type(0, trt.DataType.FP8)
     set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
     q_output = quantize_layer.get_output(0)
     # Add DQ node
     dequantize_layer = ctx.net.add_dequantize(q_output, scale)
     set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
-    # Set DQ layer precision to FP8
-    dequantize_layer.precision = trt.DataType.FP8
+    if quantize_type == "int8":
+        dequantize_layer.precision = trt.DataType.INT8
+    else:
+        # Set DQ layer precision to FP8
+        dequantize_layer.precision = trt.DataType.FP8
     dq_output = dequantize_layer.get_output(0)
 
     return dq_output
+
+
+def quantize_fp8(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_tensor: TRTTensor,
+    scale: np.ndarray,
+) -> TRTTensor:
+    return quantize(ctx, target, source_ir, name, "fp8", input_tensor, scale)
+
+
+def quantize_int8(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_tensor: TRTTensor,
+    scale: np.ndarray,
+) -> TRTTensor:
+    return quantize(ctx, target, source_ir, name, "int8", input_tensor, scale)
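
For intuition about what the generated Q/DQ pair computes: with symmetric quantization, the quantize layer divides by the scale and rounds/clamps to the target type's range, and the dequantize layer multiplies back by the same scale, so downstream layers still see values in the original dynamic range. A rough INT8 sketch in plain PyTorch (illustrative only; fake_quantize_int8 is a made-up helper, not the TensorRT layers this converter emits):

import torch

def fake_quantize_int8(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Quantize: scale, round, and clamp to the signed 8-bit range [-128, 127]
    q = torch.clamp(torch.round(x / scale), -128, 127)
    # Dequantize: multiply back by the same scale
    return q * scale

x = torch.randn(4)
print(fake_quantize_int8(x, scale=0.05))  # x reproduced at int8 resolution

The same scale tensor is passed to both add_quantize and add_dequantize above, which is why quantize() builds the two layers as a matched pair.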
43 changes: 43 additions & 0 deletions tests/py/dynamo/models/test_models_export.py
@@ -229,3 +229,46 @@ def calibrate_loop(model):
             )
             outputs_trt = trt_model(input_tensor)
             assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
+
+
+@pytest.mark.unit
+def test_base_int8(ir):
+    class SimpleNetwork(torch.nn.Module):
+        def __init__(self):
+            super(SimpleNetwork, self).__init__()
+            self.linear1 = torch.nn.Linear(in_features=10, out_features=5)
+            self.linear2 = torch.nn.Linear(in_features=5, out_features=1)
+
+        def forward(self, x):
+            x = self.linear1(x)
+            x = torch.nn.ReLU()(x)
+            x = self.linear2(x)
+            return x
+
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+
+    def calibrate_loop(model):
+        """Simple calibration function for testing."""
+        model(input_tensor)
+
+    input_tensor = torch.randn(1, 10).cuda()
+    model = SimpleNetwork().eval().cuda()
+
+    quant_cfg = mtq.INT8_DEFAULT_CFG
+    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    # model has INT8 qdq nodes at this point
+    output_pyt = model(input_tensor)
+
+    with torch.no_grad():
+        with export_torch_mode():
+            exp_program = torch.export.export(model, (input_tensor,))
+            trt_model = torchtrt.dynamo.compile(
+                exp_program,
+                inputs=[input_tensor],
+                enabled_precisions={torch.int8},
+                min_block_size=1,
+                debug=True,
+            )
+            outputs_trt = trt_model(input_tensor)
+            assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)