Graph build from dataloader (#2196)
### Changes
Extends the `ModelInputInfo` mechanism used to specify inputs to
`NNCFNetwork` for graph building and exporting. The input info can now be
specified either as a `FillerInputInfo`, which works much the same as before
and uses the NNCF config file as the source of the input tensor
specification, or as an `ExactInputInfo`, which allows specifying the exact
forward arguments for graph building. The latter is used to build the model
graph from the outputs of dataloaders attached to the `NNCFConfig` in the
QAT API when the `"input_info"` field is not specified in the `NNCFConfig`,
and also in the PTQ API flow to build the graph from the output of the
calibration dataset.
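
A minimal sketch of the two resulting ways to provide input info (illustrative only, not code from this commit; the toy model, shapes, and random dataset are assumptions):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from nncf import NNCFConfig
from nncf.torch import create_compressed_model
from nncf.torch.initialization import register_default_init_args

model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10)
)

# Path 1 - FillerInputInfo: "input_info" in the config describes the filler tensors
# that NNCF generates and passes to forward() when building the graph (same as before).
filler_config = NNCFConfig.from_dict(
    {
        "input_info": {"sample_size": [1, 3, 32, 32]},
        "compression": {"algorithm": "quantization"},
    }
)

# Path 2 - dataloader-based: "input_info" is omitted, and the input info is deduced
# from an item produced by the init dataloader attached to the config.
loader = DataLoader(
    TensorDataset(torch.randn(16, 3, 32, 32), torch.randint(0, 10, (16,))), batch_size=2
)
loader_config = NNCFConfig.from_dict({"compression": {"algorithm": "quantization"}})
loader_config = register_default_init_args(loader_config, loader)

# Either config could be passed here; the dataloader-based one is used for illustration.
ctrl, compressed_model = create_compressed_model(model, loader_config)
```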

### Reason for changes
Previously, the PTQ API had to supply its own `wrap_inputs_fn`,
`wrap_outputs_fn`, and `dummy_forward_fn` to make `NNCFNetwork` build its
graph from the outputs of the calibration dataloader; these functions were
mostly copy-pasted from the QAT path to preserve basic NNCF PT functionality
such as traced tensor expiry, same-tensor replication, etc. The new approach
allows this code to be reused. QAT use cases that specify init dataloaders
also become simpler, since the `"input_info"` field in the `NNCFConfig` may
now be omitted.

### Related tickets
N/A

### Tests

- `tests.torch.test_graph_building.test_input_info_args_are_passed_into_forward`
- `tests.torch.test_graph_building.test_filler_input_info_arg_generation`
- `tests.torch.test_graph_building.test_compressed_model_creation_can_build_exact_input_infos_from_dataloader_in_config`
- `tests.torch.ptq.test_quantize_model_helpers.test_create_nncf_network_with_nncf_dataset`
vshampor authored Nov 7, 2023
1 parent 1ce5852 commit e00f6b7
Showing 38 changed files with 803 additions and 572 deletions.
6 changes: 3 additions & 3 deletions examples/torch/classification/main.py
@@ -73,7 +73,7 @@
from nncf.config.utils import is_accuracy_aware_training
from nncf.torch import create_compressed_model
from nncf.torch.checkpoint_loading import load_state
from nncf.torch.dynamic_graph.graph_tracer import create_input_infos
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
from nncf.torch.initialization import default_criterion_fn
from nncf.torch.initialization import register_default_init_args
from nncf.torch.structures import ExecutionParameters
@@ -450,8 +450,8 @@ def create_datasets(config):
elif dataset_config in ["mock_32x32", "mock_299x299"]:
normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

input_info_list = create_input_infos(config)
image_size = input_info_list[0].shape[-1]
input_info = FillerInputInfo.from_nncf_config(config)
image_size = input_info.elements[0].shape[-1]
size = int(image_size / 0.875)
if dataset_config in ["cifar10", "cifar100_224x224", "cifar100"]:
list_val_transforms = [transforms.ToTensor(), normalize]
10 changes: 5 additions & 5 deletions examples/torch/object_detection/dataset.py
@@ -21,7 +21,7 @@
from examples.torch.object_detection.datasets.voc0712 import VOCAnnotationTransform
from examples.torch.object_detection.datasets.voc0712 import VOCDetection
from examples.torch.object_detection.utils.augmentations import SSDAugmentation
from nncf.torch.dynamic_graph.graph_tracer import create_input_infos
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo

VOC_MEAN = (0.406, 0.456, 0.485)
VOC_STD = (0.255, 0.224, 0.229)
@@ -45,8 +45,8 @@ def get_training_dataset(dataset_name, path_to_annotations, path_to_imgs, config
# for VOC path_to_imgs = path_to_annotations = voc_root
assert dataset_name in ["voc", "coco"]
preprocessing = get_preprocessing(config)
input_info_list = create_input_infos(config)
image_size = input_info_list[0].shape[-1]
input_info = FillerInputInfo.from_nncf_config(config)
image_size = input_info.elements[0].shape[-1]
ssd_transform = SSDAugmentation(image_size, preprocessing.mean, preprocessing.std, preprocessing.normalize_coef)
if dataset_name == "voc":
training_dataset = VOCDetection(
@@ -72,8 +72,8 @@ def get_testing_dataset(dataset_name, path_to_annotations, path_to_imgs, config)
# for VOC path_to_imgs = path_to_annotations = voc_root
assert dataset_name in ["voc", "coco"]
preprocessing = get_preprocessing(config)
input_info_list = create_input_infos(config)
image_size = input_info_list[0].shape[-1]
input_info = FillerInputInfo.from_nncf_config(config)
image_size = input_info.elements[0].shape[-1]
transform = BaseTransform(image_size, preprocessing.mean, preprocessing.std, preprocessing.normalize_coef)
if dataset_name == "voc":
testing_dataset = VOCDetection(
6 changes: 3 additions & 3 deletions examples/torch/object_detection/main.py
@@ -60,7 +60,7 @@
from nncf.config.utils import is_accuracy_aware_training
from nncf.torch import create_compressed_model
from nncf.torch import load_state
from nncf.torch.dynamic_graph.graph_tracer import create_input_infos
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
from nncf.torch.initialization import register_default_init_args
from nncf.torch.utils import is_main_process

@@ -363,8 +363,8 @@ def create_train_data_loader(batch_size):


def create_model(config: SampleConfig):
input_info_list = create_input_infos(config.nncf_config)
image_size = input_info_list[0].shape[-1]
input_info = FillerInputInfo.from_nncf_config(config.nncf_config)
image_size = input_info.elements[0].shape[-1]
ssd_net = build_ssd(config.model, config.ssd_params, image_size, config.num_classes, config)
weights = config.get("weights")
if weights:
2 changes: 1 addition & 1 deletion nncf/common/exporter.py
@@ -47,7 +47,7 @@ def __init__(
self._model = model
self._input_names = input_names
self._output_names = output_names
self._model_args = model_args if model_args else ({},)
self._model_args = model_args

@abstractmethod
def export_model(self, save_path: str, save_format: Optional[str] = None) -> None:
2 changes: 1 addition & 1 deletion nncf/common/graph/graph.py
@@ -143,7 +143,7 @@ def __init__(
self.to_node = to_node
self.input_port_id = input_port_id
self.output_port_id = output_port_id
self.tensor_shape = tensor_shape
self.tensor_shape: Tuple[int] = tuple(tensor_shape)
self.dtype = dtype
self.parallel_input_port_ids = parallel_input_port_ids

2 changes: 1 addition & 1 deletion nncf/config/config.py
@@ -80,7 +80,7 @@ def get_extra_struct(self, struct_cls: Type[NNCFExtraConfigStruct]) -> NNCFExtra
def has_extra_struct(self, struct_cls: Type[NNCFExtraConfigStruct]) -> NNCFExtraConfigStruct:
return struct_cls.get_id() in self.__nncf_extra_structs

def get_all_extra_structs_for_copy(self) -> List[NNCFExtraConfigStruct]:
def get_all_extra_structs(self) -> List[NNCFExtraConfigStruct]:
return list(self.__nncf_extra_structs.values())

def get_redefinable_global_param_value_for_algo(self, param_name: str, algo_name: str) -> Optional[str]:
4 changes: 4 additions & 0 deletions nncf/config/extractors.py
@@ -223,3 +223,7 @@ def validate_accuracy_aware_schema(config: NNCFConfig, params: Dict[str, object]
validate_accuracy_aware_schema(config, params)

return params


def has_input_info_field(config: NNCFConfig) -> bool:
return config.get("input_info") is not None
12 changes: 5 additions & 7 deletions nncf/config/schema.py
@@ -70,7 +70,9 @@
" during tracing and exporting.",
),
"keyword": with_attributes(
STRING, description="Keyword to be used when passing the tensor to the model's 'forward' method."
STRING,
description="Keyword to be used when passing the tensor to the model's 'forward' method - "
"leave unspecified to pass the corresponding argument as a positional arg.",
),
},
"additionalProperties": False,
@@ -105,11 +107,8 @@
"This information is used to build the internal graph representation "
"that is leveraged for proper compression functioning, and for "
"exporting the compressed model to an executable format.\n"
"For instance, in PyTorch a dummy tensor with a "
"corresponding shape and filler will be generated for each entry "
"and passed as a corresponding argument into the model's forward "
"method. Keywords can be specified for each entry - if left "
"unspecified, the dummy tensor will be passed as a positional arg.",
"If this field is unspecified, NNCF will try to deduce the input shapes and tensor types for the graph "
"building purposes based on dataloader objects that are passed to compression algorithms by the user.",
),
"target_device": with_attributes(
TARGET_DEVICE_SCHEMA,
@@ -148,7 +147,6 @@
),
"log_dir": with_attributes(STRING, description="Log directory for NNCF-specific logging outputs."),
},
"required": ["input_info"],
"$defs": REF_VS_ALGO_SCHEMA,
}

14 changes: 13 additions & 1 deletion nncf/torch/dynamic_graph/graph.py
@@ -13,6 +13,7 @@

import networkx as nx
import networkx.algorithms.isomorphism as iso
import torch
from torch import Tensor

from nncf import nncf_logger
@@ -608,6 +609,7 @@ def __init__(self):
self.match_manager = NodeManager(self._node_id_to_key_dict, self._nx_graph)
self._input_nncf_nodes = []
self._output_nncf_nodes = []
self._integer_input_nodes = []

def __eq__(self, other: "DynamicGraph"):
nm = iso.categorical_node_match(
@@ -638,7 +640,7 @@ def add_node(
op_address: OperationAddress,
tensor_metas: List[TensorMeta],
input_comparators_per_scope: List[Tuple[TensorMetaComparator, List[str]]],
inputs,
inputs: OperatorInput,
node_parameters: DynamicGraphNodeParameters,
) -> DynamicGraphNode:
node = self.match_manager.add_node(
@@ -650,6 +652,13 @@

if node.op_exec_context.operator_name == MODEL_INPUT_OP_NAME:
self._input_nncf_nodes.append(node)
# Currently the MODEL_INPUT_OP_NAME node is added when an input is wrapped as
# _ = nncf_model_input(input_tensor)
# so it is expected that the 0-th positional arg will be the torch.Tensor we need to inspect
tensor_input = inputs.op_args[0]
assert isinstance(tensor_input, torch.Tensor)
if tensor_input.dtype in (torch.int32, torch.int64, torch.long):
self._integer_input_nodes.append(node)

if node.op_exec_context.operator_name == MODEL_OUTPUT_OP_NAME:
self._output_nncf_nodes.append(node)
@@ -658,6 +667,9 @@
def get_input_nodes(self) -> List[DynamicGraphNode]:
return self._input_nncf_nodes

def is_integer_input_node(self, node: DynamicGraphNode) -> bool:
return node in self._integer_input_nodes

def get_output_nodes(self) -> List[DynamicGraphNode]:
return self._output_nncf_nodes

98 changes: 14 additions & 84 deletions nncf/torch/dynamic_graph/graph_tracer.py
@@ -8,87 +8,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, List, Optional
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar

import torch

from nncf.torch.dynamic_graph.context import TracingContext
from nncf.torch.dynamic_graph.graph import DynamicGraph
from nncf.torch.dynamic_graph.io_handling import ModelInputInfo
from nncf.torch.utils import get_model_device


class ModelInputInfo:
FILLER_TYPE_ONES = "ones"
FILLER_TYPE_ZEROS = "zeros"
FILLER_TYPE_RANDOM = "random"
FILLER_TYPES = [FILLER_TYPE_ONES, FILLER_TYPE_ZEROS, FILLER_TYPE_RANDOM]

def __init__(self, shape: List[int], type_str: str = "float", keyword=None, filler=None):
self.shape = shape
self.type = self._string_to_torch_type(type_str)
self.keyword = keyword
if filler is None:
self.filler = self.FILLER_TYPE_ONES
else:
self.filler = filler
if self.filler not in self.FILLER_TYPES:
raise RuntimeError("Unknown input filler type: {}".format(filler))

@staticmethod
def _string_to_torch_type(string):
if string == "long":
return torch.long
return torch.float32

@staticmethod
def torch_type_to_string(dtype: torch.dtype):
if dtype is torch.long:
return "long"
return "float"

def is_integer_input(self):
return self.type != torch.float32

def __eq__(self, other):
return self.type == other.type and self.keyword == other.keyword


def create_input_infos(config) -> Optional[List[ModelInputInfo]]:
input_infos = config.get("input_info")
if input_infos is None:
return input_infos
if isinstance(input_infos, dict):
return [
ModelInputInfo(
input_infos.get("sample_size"),
input_infos.get("type"),
input_infos.get("keyword"),
input_infos.get("filler"),
),
]
if isinstance(input_infos, list):
return [
ModelInputInfo(
info_dict.get("sample_size"), info_dict.get("type"), info_dict.get("keyword"), info_dict.get("filler")
)
for info_dict in input_infos
]
raise RuntimeError("Invalid input_infos specified in config - should be either dict or list of dicts")


def create_mock_tensor(input_info: ModelInputInfo, device: str):
args = {"size": input_info.shape, "dtype": input_info.type, "device": device}
if input_info.filler == ModelInputInfo.FILLER_TYPE_ZEROS:
return torch.zeros(**args)
if input_info.filler == ModelInputInfo.FILLER_TYPE_ONES:
return torch.ones(**args)
if input_info.filler == ModelInputInfo.FILLER_TYPE_RANDOM:
return torch.rand(**args)
raise RuntimeError


class GraphTracer:
def __init__(self, custom_forward_fn: Callable[[torch.nn.Module], Any]):
self.custom_forward_fn = custom_forward_fn
@@ -118,29 +48,29 @@ def trace_graph(
return context_to_use.graph


T = TypeVar("T")
WrapInputsFnType = Callable[[Tuple, Dict], Tuple[Tuple, Dict]]
WrapOutputsFnType = Callable[[T], T]


def create_dummy_forward_fn(
input_infos: List[ModelInputInfo],
with_input_tracing=False,
wrap_inputs_fn=None,
wrap_outputs_fn=None,
with_output_tracing=False,
input_info: ModelInputInfo,
with_input_tracing: bool = False,
wrap_inputs_fn: WrapInputsFnType = None,
wrap_outputs_fn: WrapOutputsFnType = None,
with_output_tracing: bool = False,
):
def default_dummy_forward_fn(model):
from nncf.torch.dynamic_graph.io_handling import replicate_same_tensors
from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_outputs_with_objwalk

device = get_model_device(model)
args_list = [create_mock_tensor(info, device) for info in input_infos if info.keyword is None]
kwargs = OrderedDict()
for info in input_infos:
if info.keyword is not None:
kwargs[info.keyword] = create_mock_tensor(info, device)
args = tuple(args_list)
args, kwargs = input_info.get_forward_inputs(device=str(device))

if with_input_tracing:
if wrap_inputs_fn is None:
# We control the input argument structure w.r.t. tensors
# We control the input argument structure w.r.t. tensors if input_info is a FillerInputInfo
# - a simple objwalk application should be sufficient in this simple case.
# For more control, wrap_inputs_fn is used when this is used in NNCFNetwork
# which is guaranteed to be the same as during the actual NNCFNetwork.forward