diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index 881f40f..8570aec 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -11,6 +11,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+import torch.nn.utils.parametrize as parametrize
 
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
@@ -19,6 +20,7 @@
     e4m3_dtype,
     e5m2_dtype,
 )
+from float8_experimental.inference import Float8InferenceLinear, QuantConfig
 from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor
 
 log = logging.getLogger(__name__)
@@ -175,6 +177,19 @@ def swap_linear_with_float8_linear(
     emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
 ) -> Optional[nn.Module]:
+    """Entrypoint for swapping the linear layers of an existing nn.Module with float8 linear layers.
+
+    Note:
+        If applied to a root-level nn.Linear, the module will not be modified in place;
+        the converted module will be returned instead.
+
+    Args:
+        module: The root-level nn.Module to modify.
+        module_cls: The class to swap the linear layers with.
+        skip_fqn_list: List of module FQNs to skip during conversion.
+        emulate: Whether to enable float8 emulation.
+        linear_layer_filter: If specified, only the linear layers that pass the filter function will be swapped.
+    """
     return swap_linear_layers(
         module,
         lambda m: module_cls.from_float(m, emulate=emulate),
@@ -183,6 +198,39 @@ def swap_linear_with_float8_linear(
     )
 
 
+def quantize_to_float8(
+    module: nn.Module,
+    quant_config: QuantConfig,
+    *,
+    skip_fqn_list: Optional[List[str]] = None,
+    use_fast_accum: bool = True,
+) -> Optional[nn.Module]:
+    """
+    Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
+
+    Note:
+        If applied to a root-level nn.Linear, the module will not be modified in place;
+        the converted module will be returned instead.
+
+    Args:
+        module: The module to modify.
+        quant_config: Quantization configuration for Float8 conversion.
+        skip_fqn_list: List of module FQNs to skip during conversion.
+        use_fast_accum: Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
+
+    Returns:
+        nn.Module: The modified module with applicable Linear layers converted to Float8.
+
+    Raises:
+        AssertionError: If a root-level nn.Linear with children is encountered.
+    """
+    return swap_linear_layers(
+        module,
+        lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
+        skip_fqn_list=skip_fqn_list,
+    )
+
+
 def get_float8_layers(model: torch.nn.Module):
     """Iterates through the model and returns all the Float8Linear layers.
     Args:
@@ -347,3 +395,54 @@ def inner_func():
         for child in fp8_layers:
             # Set a flag to signal amaxes/scales are ready
             child.amax_and_scale_synced = True
+
+
+# TODO: Remove me when the export utils land upstream
+class UnwrapTensorSubclass(torch.nn.Module):
+    def forward(self, *tensors):
+        todo = list(tensors)
+        for tp, meta, inner_tensors in reversed(self.rebuild_stack):
+            nb_tensor = len(inner_tensors)
+            inner_tensors = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])}
+            todo = todo[nb_tensor:]
+            rebuilt = tp.__tensor_unflatten__(inner_tensors, meta, None, None)
+            todo.append(rebuilt)
+
+        assert len(todo) == 1
+        return todo[0]
+
+    def right_inverse(self, tensor: torch.Tensor) -> List[torch.Tensor]:
+        assert type(tensor) is not torch.Tensor, "Expected a wrapper tensor subclass!"
+        rebuild_stack = []
+        plain_tensors = []
+        todo = [tensor]
+        while todo:
+            obj = todo.pop()
+            inner_tensors, metadata = obj.__tensor_flatten__()
+            rebuild_stack.append((type(obj), metadata, inner_tensors))
+            for attr_name in inner_tensors:
+                val = getattr(obj, attr_name)
+                if type(val) is torch.Tensor:
+                    plain_tensors.append(val)
+                else:
+                    assert isinstance(val, torch.Tensor)
+                    todo.append(val)
+
+        self.rebuild_stack = rebuild_stack
+
+        return plain_tensors
+
+
+def unwrap_tensor_subclass(model, filter_fn=None) -> nn.Module:
+    for _, child in model.named_children():
+        if (
+            isinstance(child, Float8InferenceLinear)
+            and hasattr(child, "weight")
+            and type(child.weight) is not torch.Tensor
+            and isinstance(child.weight, torch.Tensor)
+        ):
+            parametrize.register_parametrization(
+                child, "weight", UnwrapTensorSubclass()
+            )
+        unwrap_tensor_subclass(child)
+    return model
diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py
index 1c931ee..42c3bdb 100644
--- a/float8_experimental/inference.py
+++ b/float8_experimental/inference.py
@@ -10,13 +10,12 @@
 
 from dataclasses import dataclass
 from enum import auto, Enum
-from typing import List, Optional
+from typing import Optional
 
 import float8_experimental.config as config
 
 import torch
 import torch.nn as nn
-from float8_experimental.float8_linear_utils import swap_linear_layers
 
 from float8_experimental.float8_tensor import (
     Float8Tensor,
@@ -191,36 +190,3 @@ def cast_to_float8_e4m3_inference(
         else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
     )
     return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
-
-
-def quantize_to_float8(
-    module: nn.Module,
-    quant_config: QuantConfig,
-    *,
-    skip_fqn_list: Optional[List[str]] = None,
-    use_fast_accum: bool = True,
-) -> Optional[nn.Module]:
-    """
-    Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
-
-    Note:
-        If applied to a root-level nn.Linear, the module will not be modified in place
-        and returned instead
-
-    Args:
-        module (nn.Module): The module to modify.
-        quant_config (QuantConfig): Quantization configuration for Float8 conversion.
-        skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
-        use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
-
-    Returns:
-        nn.Module: The modified module with applicable Linear layers converted to Float8.
-
-    Raises:
-        AssertionError: If a root-level nn.Linear with children is encountered.
-    """
-    return swap_linear_layers(
-        module,
-        lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
-        skip_fqn_list=skip_fqn_list,
-    )
diff --git a/test/test_base.py b/test/test_base.py
index 7ce0b7b..9fa9df2 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -22,6 +22,7 @@
     get_float8_linear,
     linear_requires_sync,
     LinearType,
+    quantize_to_float8,
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
@@ -39,11 +40,7 @@
     FP8_TYPES,
     tensor_to_scale,
 )
-from float8_experimental.inference import (
-    ActivationCasting,
-    QuantConfig,
-    quantize_to_float8,
-)
+from float8_experimental.inference import ActivationCasting, QuantConfig
 
 random.seed(0)
 torch.manual_seed(0)
diff --git a/test/test_inference_flows.py b/test/test_inference_flows.py
index b0c00c6..679d47b 100644
--- a/test/test_inference_flows.py
+++ b/test/test_inference_flows.py
@@ -5,29 +5,34 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import io
+import os
 import random
 import unittest
 
 import pytest
 
 import torch
+
+import torch._inductor
 import torch.nn as nn
 import torch.nn.functional as F
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
-from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
+from float8_experimental.float8_linear_utils import (
+    quantize_to_float8,
+    swap_linear_with_float8_linear,
+    unwrap_tensor_subclass,
+)
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import compute_error
 from float8_experimental.inference import (
     ActivationCasting,
     Float8InferenceLinear,
     QuantConfig,
-    quantize_to_float8,
 )
-
+from torch.export._trace import _export as _export_private
 
 random.seed(0)
 torch.manual_seed(0)
-
 
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
@@ -242,5 +247,57 @@ def test_fp8_save_and_load(self, dtype: torch.dtype):
         assert torch.all(og_out == new_out).item()
 
 
+class TestFP8Export:
+    @unittest.skipIf(
+        not torch.cuda.is_available() or not is_H100,
+        "CUDA not available or on non H100 machine",
+    )
+    def test_fp8_export(self):
+        export_model = FeedForward().to("cuda")
+        quant_config = QuantConfig(ActivationCasting.DYNAMIC)
+        quantize_to_float8(export_model, quant_config)
+        batch_size = 4
+        num_tokens = 1024
+        embedding_dim = 4096
+
+        inp = torch.randn(
+            batch_size, num_tokens, embedding_dim, device="cuda", dtype=torch.float32
+        )
+        example_args = (inp,)
+
+        fp8_compile_model = copy.deepcopy(export_model)
+        fp8_compile_model = torch.compile(fp8_compile_model)
+        fp8_compile_out = fp8_compile_model(*example_args)
+
+        # Export model with subclass weights
+
+        export_model = unwrap_tensor_subclass(export_model)
+
+        # Export the model
+        exported_model = _export_private(
+            export_model,
+            example_args,
+            strict=False,
+            pre_dispatch=False,
+        )
+
+        so_path = None
+        try:
+            # Compile the exported program to a .so using AOTInductor
+            with torch.no_grad():
+                so_path = torch._inductor.aot_compile(
+                    exported_model.module(), example_args
+                )
+
+            # Load and run the .so file in Python
+            res = torch._export.aot_load(so_path, device="cuda")(example_args)
+            torch.testing.assert_close(fp8_compile_out, res)
+
+        finally:
+            # Cleanup: remove the .so file
+            if so_path and os.path.exists(so_path):
+                os.remove(so_path)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
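
For reference, a minimal sketch of the inference export flow this patch enables: quantize_to_float8 swaps nn.Linear layers for Float8InferenceLinear, unwrap_tensor_subclass re-parametrizes the Float8Tensor weights into plain tensors so export does not trace through the wrapper subclass, and the result is exported and AOT-compiled as in TestFP8Export. Assumptions: an H100-class (SM90) GPU, float8_experimental installed, and a toy bias-free nn.Sequential standing in for the test's FeedForward module.

    # Sketch only; mirrors the calls made in TestFP8Export, with a stand-in model.
    import torch
    import torch._inductor
    import torch.nn as nn

    from float8_experimental.float8_linear_utils import (
        quantize_to_float8,
        unwrap_tensor_subclass,
    )
    from float8_experimental.inference import ActivationCasting, QuantConfig
    from torch.export._trace import _export as _export_private

    # Toy model (assumption: the real test uses its FeedForward module instead).
    model = nn.Sequential(
        nn.Linear(4096, 4096, bias=False),
        nn.ReLU(),
        nn.Linear(4096, 4096, bias=False),
    ).to("cuda")

    # Swap nn.Linear layers for Float8InferenceLinear with dynamic activation casting.
    quantize_to_float8(model, QuantConfig(ActivationCasting.DYNAMIC))

    # Re-parametrize the Float8Tensor weights into plain tensors before export.
    model = unwrap_tensor_subclass(model)

    example_args = (torch.randn(8, 4096, device="cuda"),)
    exported = _export_private(model, example_args, strict=False, pre_dispatch=False)

    with torch.no_grad():
        # AOTInductor compiles the exported program into a shared library.
        so_path = torch._inductor.aot_compile(exported.module(), example_args)

    # Load the .so back into Python and run it.
    compiled = torch._export.aot_load(so_path, device="cuda")
    out = compiled(*example_args)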