This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Adds a test comparing the output of torch.compile and export #295

Open · wants to merge 1 commit into main

99 changes: 99 additions & 0 deletions float8_experimental/float8_linear_utils.py
@@ -11,6 +11,7 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear

Expand All @@ -19,6 +20,7 @@
    e4m3_dtype,
    e5m2_dtype,
)
from float8_experimental.inference import Float8InferenceLinear, QuantConfig
from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

log = logging.getLogger(__name__)
@@ -175,6 +177,19 @@ def swap_linear_with_float8_linear(
    emulate: bool = False,
    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
) -> Optional[nn.Module]:
"""Entrypoint for swapping linear layers with float8 for an existing nn.Module

Note:
If applied to a root-level nn.Linear, the module will not be modified in place
and returned instead

Args:
module: The root-level nn.Module to modify
module_cls: The class to swap the linear layers with
skip_fqn_list: List of module FQNs to skip during conversion.
emulate: Whether to enable float8 emulation.
linear_layer_filter: If specified, only the linear layers that pass the filter function will be swapped.
"""
    return swap_linear_layers(
        module,
        lambda m: module_cls.from_float(m, emulate=emulate),
@@ -183,6 +198,39 @@
    )


def quantize_to_float8(
    module: nn.Module,
    quant_config: QuantConfig,
    *,
    skip_fqn_list: Optional[List[str]] = None,
    use_fast_accum: bool = True,
) -> Optional[nn.Module]:
"""
Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.

Note:
If applied to a root-level nn.Linear, the module will not be modified in place
and returned instead

Args:
module: The module to modify.
quant_config: Quantization configuration for Float8 conversion.
skip_fqn_list: List of module FQNs to skip during conversion.
use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.

Returns:
nn.Module: The modified module with applicable Linear layers converted to Float8.

Raises:
AssertionError: If a root-level nn.Linear with children is encountered.
"""
    return swap_linear_layers(
        module,
        lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
        skip_fqn_list=skip_fqn_list,
    )


def get_float8_layers(model: torch.nn.Module):
    """Iterates through the model and returns all the Float8Linear layers.
    Args:
@@ -347,3 +395,54 @@ def inner_func():
    for child in fp8_layers:
        # Set a flag to signal amaxes/scales are ready
        child.amax_and_scale_synced = True


# TODO: Remove me when the export utils land upstream.
# Parametrization that flattens a wrapper tensor subclass into its plain inner
# tensors (right_inverse) and rebuilds the subclass on access (forward), so
# that export only ever traces over plain tensors.
class UnwrapTensorSubclass(torch.nn.Module):
    def forward(self, *tensors):
        todo = list(tensors)
        for tp, meta, inner_tensors in reversed(self.rebuild_stack):
            nb_tensor = len(inner_tensors)
            inner_tensors = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])}
            todo = todo[nb_tensor:]
            rebuilt = tp.__tensor_unflatten__(inner_tensors, meta, None, None)
            todo.append(rebuilt)

        assert len(todo) == 1
        return todo[0]

    def right_inverse(self, tensor: torch.Tensor) -> List[torch.Tensor]:
        assert type(tensor) is not torch.Tensor, "Expected a wrapper tensor subclass!"
        rebuild_stack = []
        plain_tensors = []
        todo = [tensor]
        while todo:
            obj = todo.pop()
            inner_tensors, metadata = obj.__tensor_flatten__()
            rebuild_stack.append((type(obj), metadata, inner_tensors))
            for attr_name in inner_tensors:
                val = getattr(obj, attr_name)
                if type(val) is torch.Tensor:
                    plain_tensors.append(val)
                else:
                    assert isinstance(val, torch.Tensor)
                    todo.append(val)

        self.rebuild_stack = rebuild_stack

        return plain_tensors


def unwrap_tensor_subclass(model, filter_fn=None) -> nn.Module:
    for _, child in model.named_children():
        # Only parametrize weights that are wrapper tensor subclasses
        # (instances of torch.Tensor whose exact type is not torch.Tensor).
        if (
            isinstance(child, Float8InferenceLinear)
            and hasattr(child, "weight")
            and type(child.weight) is not torch.Tensor
            and isinstance(child.weight, torch.Tensor)
        ):
            parametrize.register_parametrization(
                child, "weight", UnwrapTensorSubclass()
            )
        unwrap_tensor_subclass(child)
    return model
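For reference, here is a minimal usage sketch of the two entry points added above, quantize_to_float8 and unwrap_tensor_subclass. The model, shapes, and variable names are illustrative assumptions rather than part of this PR, and the export call mirrors the private _export entrypoint used by the new test further down. Note also that quantize_to_float8 moves here from inference.py (next diff), so the import between float8_linear_utils.py and inference.py stays one-way.

import torch
import torch.nn as nn

from float8_experimental.float8_linear_utils import (
    quantize_to_float8,
    unwrap_tensor_subclass,
)
from float8_experimental.inference import ActivationCasting, QuantConfig
from torch.export._trace import _export as _export_private


class TinyMLP(nn.Module):
    # Illustrative stand-in; any nn.Module containing nn.Linear children works.
    def __init__(self, dim: int = 4096):
        super().__init__()
        self.w1 = nn.Linear(dim, dim, bias=False)
        self.w2 = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.relu(self.w1(x)))


model = TinyMLP().to("cuda")

# Swap every eligible nn.Linear for a Float8InferenceLinear that casts
# activations dynamically. (Actually running the float8 matmuls needs
# H100-class hardware, as in the test below.) The training-side entry point
# documented above is used the same way, e.g.
# swap_linear_with_float8_linear(model, Float8DynamicLinear).
quantize_to_float8(model, QuantConfig(ActivationCasting.DYNAMIC))

# Register the UnwrapTensorSubclass parametrization so export traces over
# plain tensors rather than Float8Tensor wrapper subclasses on the weights.
model = unwrap_tensor_subclass(model)

example_args = (torch.randn(8, 4096, device="cuda"),)
with torch.no_grad():
    exported = _export_private(model, example_args, strict=False, pre_dispatch=False)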
36 changes: 1 addition & 35 deletions float8_experimental/inference.py
@@ -10,13 +10,12 @@
from dataclasses import dataclass

from enum import auto, Enum
from typing import List, Optional
from typing import Optional

import float8_experimental.config as config

import torch
import torch.nn as nn
from float8_experimental.float8_linear_utils import swap_linear_layers

from float8_experimental.float8_tensor import (
    Float8Tensor,
@@ -191,36 +190,3 @@ def cast_to_float8_e4m3_inference(
        else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
    )
    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)


def quantize_to_float8(
    module: nn.Module,
    quant_config: QuantConfig,
    *,
    skip_fqn_list: Optional[List[str]] = None,
    use_fast_accum: bool = True,
) -> Optional[nn.Module]:
    """
    Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.

    Note:
        If applied to a root-level nn.Linear, the module will not be modified in place
        and returned instead

    Args:
        module (nn.Module): The module to modify.
        quant_config (QuantConfig): Quantization configuration for Float8 conversion.
        skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
        use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.

    Returns:
        nn.Module: The modified module with applicable Linear layers converted to Float8.

    Raises:
        AssertionError: If a root-level nn.Linear with children is encountered.
    """
    return swap_linear_layers(
        module,
        lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
        skip_fqn_list=skip_fqn_list,
    )
7 changes: 2 additions & 5 deletions test/test_base.py
@@ -22,6 +22,7 @@
    get_float8_linear,
    linear_requires_sync,
    LinearType,
    quantize_to_float8,
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
@@ -39,11 +40,7 @@
    FP8_TYPES,
    tensor_to_scale,
)
from float8_experimental.inference import (
    ActivationCasting,
    QuantConfig,
    quantize_to_float8,
)
from float8_experimental.inference import ActivationCasting, QuantConfig

random.seed(0)
torch.manual_seed(0)
65 changes: 61 additions & 4 deletions test/test_inference_flows.py
@@ -5,29 +5,34 @@
# LICENSE file in the root directory of this source tree.
import copy
import io
import os
import random
import unittest

import pytest

import torch

import torch._inductor
import torch.nn as nn
import torch.nn.functional as F
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_linear_utils import (
    quantize_to_float8,
    swap_linear_with_float8_linear,
    unwrap_tensor_subclass,
)
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import compute_error
from float8_experimental.inference import (
    ActivationCasting,
    Float8InferenceLinear,
    QuantConfig,
    quantize_to_float8,
)

from torch.export._trace import _export as _export_private

random.seed(0)
torch.manual_seed(0)

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


@@ -242,5 +247,57 @@ def test_fp8_save_and_load(self, dtype: torch.dtype):
        assert torch.all(og_out == new_out).item()


class TestFP8Export:
    @unittest.skipIf(
        not torch.cuda.is_available() or not is_H100,
        "CUDA not available or on non H100 machine",
    )
    def test_fp8_export(self):
        export_model = FeedForward().to("cuda")
        quant_config = QuantConfig(ActivationCasting.DYNAMIC)
        quantize_to_float8(export_model, quant_config)
        batch_size = 4
        num_tokens = 1024
        embedding_dim = 4096

        inp = torch.randn(
            batch_size, num_tokens, embedding_dim, device="cuda", dtype=torch.float32
        )
        example_args = (inp,)

        fp8_compile_model = copy.deepcopy(export_model)
        fp8_compile_model = torch.compile(fp8_compile_model)
        fp8_compile_out = fp8_compile_model(*example_args)

        # Unwrap the tensor-subclass weights so the model can be exported
        export_model = unwrap_tensor_subclass(export_model)

        # Export the model
        exported_model = _export_private(
            export_model,
            example_args,
            strict=False,
            pre_dispatch=False,
        )

        so_path = None
        try:
            # Compile the exported program to a .so using AOTInductor
            with torch.no_grad():
                so_path = torch._inductor.aot_compile(
                    exported_model.module(), example_args
                )

            # Load and run the .so file in Python
            res = torch._export.aot_load(so_path, device="cuda")(example_args)
            torch.testing.assert_close(fp8_compile_out, res)

        finally:
            # Cleanup: remove the .so file
            if so_path and os.path.exists(so_path):
                os.remove(so_path)


if __name__ == "__main__":
    pytest.main([__file__])
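As a follow-up to the test above: the .so produced by torch._inductor.aot_compile can also be reloaded later, for example in a separate process, rather than deleted. A rough sketch, assuming the artifact path was kept and the inputs match the shapes used at compile time:

import torch

so_path = "/tmp/fp8_feed_forward.so"  # hypothetical path returned by aot_compile

runner = torch._export.aot_load(so_path, device="cuda")
inp = torch.randn(4, 1024, 4096, device="cuda", dtype=torch.float32)
out = runner(inp)  # the loaded callable takes the example inputs positionally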