Skip to content

Commit

Permalink
Supports multi-device model inference and wrapped forward functions (#…
Browse files Browse the repository at this point in the history
…2253)

### Changes

Supports multi-device model inference and wrapped forward functions

### Reason for changes

Support tracing "bigscience/bloomz-560m" model from HF

### Related tickets

N/A

### Tests

test_no_self_forward,  test_multidevice_model
  • Loading branch information
alexsu52 authored Nov 14, 2023
1 parent 13e794b commit 75a3403
Show file tree
Hide file tree
Showing 12 changed files with 199 additions and 69 deletions.
2 changes: 1 addition & 1 deletion examples/torch/common/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def export_model(ctrl: CompressionAlgorithmController, save_path: str, no_strip_

model = model.eval().cpu()

export_args = get_export_args(model)
export_args = get_export_args(model, device="cpu")
input_names = generate_input_names_list(count_tensors(export_args))

with torch.no_grad():
Expand Down
16 changes: 14 additions & 2 deletions nncf/torch/dynamic_graph/graph_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@

import torch

from nncf.common.logging import nncf_logger
from nncf.torch.dynamic_graph.context import TracingContext
from nncf.torch.dynamic_graph.graph import DynamicGraph
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
from nncf.torch.dynamic_graph.io_handling import LoaderInputInfo
from nncf.torch.dynamic_graph.io_handling import ModelInputInfo
from nncf.torch.utils import get_model_device
from nncf.torch.utils import is_multidevice


class GraphTracer:
Expand Down Expand Up @@ -65,8 +69,16 @@ def default_dummy_forward_fn(model):
from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_outputs_with_objwalk

device = get_model_device(model)
args, kwargs = input_info.get_forward_inputs(device=str(device))
device = None
if isinstance(input_info, (FillerInputInfo, LoaderInputInfo)):
if is_multidevice(model):
nncf_logger.warning(
"Multidevice model detected when tracing the model's dynamic graph - will pass example "
"inputs to the model as-is without changing their device."
)
else:
device = get_model_device(model)
args, kwargs = input_info.get_forward_inputs(device)

if with_input_tracing:
if wrap_inputs_fn is None:
Expand Down
38 changes: 22 additions & 16 deletions nncf/torch/dynamic_graph/io_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from functools import partial
from inspect import Parameter
from inspect import Signature
from typing import Any, Dict, List, Optional, Protocol, Set, Tuple, Type
from typing import Any, Dict, List, Optional, Protocol, Set, Tuple, Type, Union

import torch

Expand Down Expand Up @@ -87,7 +87,7 @@ class ModelInputInfo(abc.ABC):
"""

@abc.abstractmethod
def get_forward_inputs(self, device: str = None) -> Tuple[Tuple, Dict]:
def get_forward_inputs(self, device: Optional[Union[str, torch.device]] = None) -> Tuple[Tuple, Dict]:
"""
Returns the tuple of (args, kwargs) for passing into the compressed model's forward method when necessary.
The returned arguments should be such that the model's forward with these arguments executes the main
Expand Down Expand Up @@ -204,7 +204,9 @@ def from_nncf_config(cls, config: NNCFConfig):
return FillerInputInfo(elements)
raise RuntimeError("Invalid input_infos specified in config - should be either dict or list of dicts")

def get_forward_inputs(self, device: str = None) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
def get_forward_inputs(
self, device: Optional[Union[str, torch.device]] = None
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
args_list = []
kwargs = {}
for fe in self.elements:
Expand All @@ -227,7 +229,7 @@ def __init__(self, forward_args: Tuple, forward_kwargs: Dict):
self._forward_args = forward_args
self._forward_kwargs = forward_kwargs

def get_forward_inputs(self, device: str = None) -> Tuple[Tuple, Dict]:
def get_forward_inputs(self, device: Optional[Union[str, torch.device]] = None) -> Tuple[Tuple, Dict]:
if device is None:
return self._forward_args, self._forward_kwargs
to_device_fn = partial(torch.Tensor.to, device=device)
Expand All @@ -237,34 +239,38 @@ def get_forward_inputs(self, device: str = None) -> Tuple[Tuple, Dict]:
kwargs_at_device = objwalk(kwargs_copy, is_tensor, to_device_fn)
return args_at_device, kwargs_at_device


class ExampleInputInfo(ExactInputsInfo):
@classmethod
def from_example_input(cls, example_input: Any) -> "ExactInputsInfo":
def from_example_input(cls, example_input: Any) -> "ExampleInputInfo":
"""
Builds an ExactInputsInfo object based on the example input.
Builds an ExampleInputInfo object based on the example input.
:param dataset: An nncf.Dataset whose first element will be used as an example model input
:return: An initialized ExactInputsInfo object.
:return: An initialized ExampleInputInfo object.
"""
if isinstance(example_input, tuple):
return ExactInputsInfo(example_input, {})
return ExampleInputInfo(example_input, {})
if isinstance(example_input, dict):
return ExactInputsInfo(tuple(), example_input)
return ExactInputsInfo((example_input,), {})
return ExampleInputInfo(tuple(), example_input)
return ExampleInputInfo((example_input,), {})

@classmethod
def from_nncf_dataset(cls, dataset: Dataset) -> "ExactInputsInfo":
def from_nncf_dataset(cls, dataset: Dataset) -> "ExampleInputInfo":
"""
Checks the first element of the provided nncf.Dataset and builds an ExactInputsInfo object that would
Checks the first element of the provided nncf.Dataset and builds an ExampleInputInfo object that would
provide the same input to the model at corresponding compression stages.
:param dataset: An nncf.Dataset whose first element will be used as an example model input
:return: An initialized ExactInputsInfo object.
:return: An initialized ExampleInputInfo object.
"""
example_input = next(iter(dataset.get_inference_data()))
return cls.from_example_input(example_input)


class LoaderInputInfo(ExactInputsInfo):
@classmethod
def from_nncf_config_dataloaders(cls, config: NNCFConfig) -> Optional["ExactInputsInfo"]:
def from_nncf_config_dataloaders(cls, config: NNCFConfig) -> Optional["LoaderInputInfo"]:
"""
Examines the user-provided structures registered with the NNCFConfig instance used for compression to find
structures that contain a dataloader. The dataloader's first element is used to provide an example input to
Expand All @@ -273,7 +279,7 @@ def from_nncf_config_dataloaders(cls, config: NNCFConfig) -> Optional["ExactInpu
:param config: An nncf.NNCFConfig instance. Must have at least one NNCFExtraConfigStruct attached that can
provide a dataloader (these are listed in nncf.torch.dynamic_graph.io_handling.EXTRA_STRUCTS_WITH_DATALOADERS)
:return: An initialized ExactInputsInfo object.
:return: An initialized LoaderInputInfo object.
"""
extra_structs = config.get_all_extra_structs()
for extra_struct in extra_structs:
Expand All @@ -283,7 +289,7 @@ def from_nncf_config_dataloaders(cls, config: NNCFConfig) -> Optional["ExactInpu
wrapped_dataloader = wrap_dataloader_for_init(dataloader)
dataloader_output = next(iter(wrapped_dataloader))
args, kwargs = wrapped_dataloader.get_inputs(dataloader_output)
return ExactInputsInfo(args, kwargs)
return LoaderInputInfo(args, kwargs)
# config extra structs had no suitable dataloaders
return None

Expand Down
8 changes: 5 additions & 3 deletions nncf/torch/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ def counter_fn(x: torch.Tensor) -> torch.Tensor:
return count


def get_export_args(model: NNCFNetwork, model_args: Optional[Tuple[Any, ...]] = None) -> Tuple:
args, kwargs = model.nncf.input_infos.get_forward_inputs()
def get_export_args(
model: NNCFNetwork, model_args: Optional[Tuple[Any, ...]] = None, device: Optional[str] = None
) -> Tuple:
args, kwargs = model.nncf.input_infos.get_forward_inputs(device)

if model_args is not None:
args = tuple(list(args) + list(model_args[:-1]))
Expand Down Expand Up @@ -142,7 +144,7 @@ def _export_to_onnx(self, save_path: str, opset_version: int) -> None:
original_device = get_model_device(self._model)
model = self._model.eval().cpu()

export_args = get_export_args(self._model, model_args=self._model_args)
export_args = get_export_args(self._model, model_args=self._model_args, device="cpu")

if self._input_names is not None:
input_names = self._input_names
Expand Down
7 changes: 4 additions & 3 deletions nncf/torch/model_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
from nncf.torch.dynamic_graph.graph_tracer import WrapInputsFnType
from nncf.torch.dynamic_graph.graph_tracer import WrapOutputsFnType
from nncf.torch.dynamic_graph.io_handling import EXTRA_STRUCTS_WITH_DATALOADERS
from nncf.torch.dynamic_graph.io_handling import ExactInputsInfo
from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
from nncf.torch.dynamic_graph.io_handling import LoaderInputInfo
from nncf.torch.dynamic_graph.io_handling import ModelInputInfo
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.utils import is_dist_avail_and_initialized
Expand Down Expand Up @@ -167,7 +168,7 @@ def get_input_info_from_config(config: NNCFConfig) -> ModelInputInfo:
nncf_logger.debug(
"Config has no 'input_info' section, trying to use dataloader output as model inputs " "for graph building."
)
exact_info = ExactInputsInfo.from_nncf_config_dataloaders(config)
exact_info = LoaderInputInfo.from_nncf_config_dataloaders(config)
if exact_info is not None:
return exact_info
raise RuntimeError(
Expand Down Expand Up @@ -323,7 +324,7 @@ def wrap_model(model: torch.nn.Module, example_input: Any) -> NNCFNetwork:
:return: A model wrapped by NNCFNetwork.
"""

input_info = ExactInputsInfo.from_example_input(example_input)
input_info = ExampleInputInfo.from_example_input(example_input)

with training_mode_switcher(model, is_training=False):
nncf_network = NNCFNetwork(model, input_info=input_info)
Expand Down
12 changes: 10 additions & 2 deletions nncf/torch/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.nncf_network import PTInsertionPoint
from nncf.torch.quantization.external_quantizer import ExternalQuantizerCallHook
from nncf.torch.utils import get_model_device
from nncf.torch.utils import is_multidevice


class PTModelTransformer(ModelTransformer):
Expand Down Expand Up @@ -119,15 +121,21 @@ def _apply_quantizer_insertion_transformations(
model.nncf.register_compression_module_type(compression_model_type)

insertion_commands: List[PTInsertionCommand] = []
device = None
if not is_multidevice(model):
device = get_model_device(model)

for transformation_command in transformations:
target_point: PTTargetPoint = transformation_command.target_point
fn = transformation_command.quantizer
quantizer_module = transformation_command.quantizer
if device is not None:
quantizer_module = quantizer_module.to(device)
fn = quantizer_module

if target_point.type is not TargetType.OPERATION_WITH_WEIGHTS:
quantizer_id = NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id)
storage_key = str(quantizer_id)
model.nncf.add_compression_module(storage_key, transformation_command.quantizer, compression_model_type)
model.nncf.add_compression_module(storage_key, quantizer_module, compression_model_type)
fn = ExternalQuantizerCallHook(model.nncf.get_tracing_context(), storage_key)

insertion_commands.append(
Expand Down
4 changes: 4 additions & 0 deletions nncf/torch/nncf_module_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nncf.torch.layers import NNCF_WRAPPED_USER_MODULES_DICT
from nncf.torch.layers import UNWRAPPED_USER_MODULES
from nncf.torch.layers import add_nncf_functionality_to_user_module
from nncf.torch.utils import get_model_device


def is_nncf_module(module: nn.Module) -> bool:
Expand Down Expand Up @@ -181,11 +182,14 @@ def replace_modules_by_nncf_modules(
scope_set, ignored_scopes, target_scopes, eval_op_scopes
) and not _is_module_only_in_user_module(scope_set)
if should_process:
device = get_model_device(module)

if custom_replacer is not None:
replaced_module = custom_replacer(module)
else:
replaced_module = nncf_module_from(module)

replaced_module.to(device)
inter_dict[replaced_module] = scope_set

new_scope_set = set()
Expand Down
32 changes: 18 additions & 14 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def get_original_forward(self) -> Callable:
have its 0-th implicit `self` argument bound to the model object.
"""
if self._original_instance_forward is not None:
return functools.partial(self._original_instance_forward, self._model_ref)
return self._original_instance_forward
return functools.partial(self._original_unbound_forward, self._model_ref)

@contextmanager
Expand Down Expand Up @@ -224,8 +224,6 @@ def __init__(
self._user_dummy_forward_fn = dummy_forward_fn
self._kd_loss_handler = None

device = get_model_device(model)

if wrap_inputs_fn is not None:
self._wrap_inputs_fn = wrap_inputs_fn
elif self._input_info is not None:
Expand Down Expand Up @@ -254,7 +252,7 @@ def __init__(
eval_op_scopes = self._collect_eval_op_scopes(model, _orig_graph_build_forward_fn)

# all modules called in eval mode should be replaced prior to graph building
self._replace_modules_by_nncf_modules(model, device, eval_op_scopes)
self._replace_modules_by_nncf_modules(model, eval_op_scopes)

_orig_context = TracingContext()

Expand Down Expand Up @@ -458,13 +456,10 @@ def wrapped_user_dummy_forward_fn(*args, **kwargs):

return wrapped_user_dummy_forward_fn

def _replace_modules_by_nncf_modules(
self, model: torch.nn.Module, device: torch.device, eval_op_scopes: List[Scope] = None
):
def _replace_modules_by_nncf_modules(self, model: torch.nn.Module, eval_op_scopes: List[Scope] = None):
_, self._nncf_replaced_modules = replace_modules_by_nncf_modules(
model, ignored_scopes=self._ignored_scopes, target_scopes=self._target_scopes, eval_op_scopes=eval_op_scopes
)
model.to(device)
return model

def get_nncf_module_scopes(self) -> List[List[Scope]]:
Expand Down Expand Up @@ -836,10 +831,12 @@ def __call__(
new_class.forward = _get_nncf_forward_function_with_signature(inspect.signature(original_class.forward))

# In case of overriding forward by code like `model.forward = wrapper(model.forward)`
forward_inst_attr_fn = original_model.__dict__.get("forward")
if forward_inst_attr_fn is not None:
new_inst_forward = _get_nncf_forward_function_with_signature(inspect.signature(forward_inst_attr_fn))
original_model.__dict__["forward"] = functools.partial(new_inst_forward, original_model)
original_instance_forward = original_model.__dict__.get("forward")
if original_instance_forward is not None:
bound_new_instance_forward = _get_nncf_forward_function_with_signature(
inspect.signature(original_instance_forward), bind_self_to=original_model
)
original_model.__dict__["forward"] = bound_new_instance_forward

# Make resulting class keep __module__ attributes of the original class,
# otherwise these will point to NNCF
Expand Down Expand Up @@ -884,15 +881,22 @@ def __eq__(cls, other):
return other is NNCFNetwork


def _get_nncf_forward_function_with_signature(signature: inspect.Signature):
def _get_nncf_forward_function_with_signature(
signature: inspect.Signature, bind_self_to: torch.nn.Module = None
) -> Callable:
"""
Create forward function with copy signature of forward function.
Creates a function that executes code from NNCFNetwork.forward, but with a final signature equal to the provided
one.
:param signature: Signature of function that will used for forward function.
:param bind_self_to: If provided, will bind the `self` argument of the returned function to the provided model
object. This should be the model object that we are currently constructing the NNCFNetwork with.
:return: New copy of function NNCFNetwork.forward with specified signature.
"""
fn = NNCFNetwork.forward
new_forward = types.FunctionType(fn.__code__, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__)
new_forward.__dict__.update(fn.__dict__)
if bind_self_to is not None:
new_forward = functools.partial(new_forward, bind_self_to)
new_forward.__signature__ = signature
if is_debug():
new_forward = debuggable_forward(new_forward)
Expand Down
20 changes: 19 additions & 1 deletion nncf/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import random
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Dict, List
from typing import Any, Dict, Generator, List

import numpy as np
import torch
Expand Down Expand Up @@ -420,6 +420,24 @@ def get_model_device(model: torch.nn.Module) -> torch.device:
return device


def get_all_model_devices_generator(model: torch.nn.Module) -> Generator[torch.device, None, None]:
for p in model.parameters():
yield p.device


def is_multidevice(model: torch.nn.Module) -> bool:
device_generator = get_all_model_devices_generator(model)
try:
curr_device = next(device_generator)
except StopIteration: # no parameters
return False

for d in device_generator:
if d != curr_device:
return True
return False


def get_model_dtype(model: torch.nn.Module) -> torch.dtype:
try:
dtype = next(model.parameters()).dtype
Expand Down
Loading

0 comments on commit 75a3403

Please sign in to comment.