
Commit

Cleanup
daniil-lyakhov committed Apr 25, 2024
1 parent 9e5197d commit fa077fa
Showing 8 changed files with 146 additions and 148 deletions.
20 changes: 0 additions & 20 deletions nncf/torch/dynamic_graph/io_handling.py
@@ -122,7 +122,6 @@ def __init__(self, shape: List[int], type_str: str = "float", keyword: str = Non
"""
self.shape = shape
self.type = self._string_to_torch_type(type_str)
self._type_str = type_str
self.keyword = keyword
if filler is None:
self.filler = self.FILLER_TYPE_ONES
@@ -158,15 +157,6 @@ def get_tensor_for_input(self) -> torch.Tensor:
return torch.rand(size=self.shape, dtype=self.type)
raise NotImplementedError

def get_state(self) -> Dict[str, Any]:
return {"shape": self.shape, "type_str": self._type_str, "keyword": self.keyword, "filler": self.filler}

@classmethod
def from_state(cls, state: Dict[str, Any]) -> "FillerInputElement":
return FillerInputElement(
shape=state["shape"], type_str=state["type_str"], keyword=state["keyword"], filler=state["filler"]
)


class FillerInputInfo(ModelInputInfo):
"""
@@ -230,16 +220,6 @@ def get_forward_inputs(
kwargs[fe.keyword] = tensor
return tuple(args_list), kwargs

def get_state(self) -> Dict[str, Any]:
return {"elements": [elem.get_state() for elem in self.elements]}

@classmethod
def from_state(cls, state) -> "FillerInputInfo":
return FillerInputInfo([FillerInputElement.from_state(s) for s in state["elements"]])

def __eq__(self, other: "FillerInputInfo") -> bool:
return self.elements == other.elements


class ExactInputsInfo(ModelInputInfo):
"""
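For orientation, a minimal usage sketch of the filler-input API touched above, assuming (from the signatures visible in this diff only) that FillerInputInfo accepts a list of FillerInputElement objects and that get_forward_inputs can be called without arguments; the shapes and keyword name are hypothetical.

from nncf.torch.dynamic_graph.io_handling import FillerInputElement, FillerInputInfo

# Hypothetical input description: one positional and one keyword tensor input.
elements = [
    FillerInputElement(shape=[1, 3, 224, 224]),                          # positional input
    FillerInputElement(shape=[1, 3, 224, 224], keyword="second_input"),  # keyword input
]
input_info = FillerInputInfo(elements)

# Dummy forward arguments: ones- or random-filled tensors, depending on each element's filler.
args, kwargs = input_info.get_forward_inputs()  # -> (tensor,), {"second_input": tensor}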
70 changes: 45 additions & 25 deletions nncf/torch/graph/transformations/serialization.py
@@ -22,14 +22,16 @@
from nncf.torch.layer_utils import COMPRESSION_MODULES

COMPRESSION_STATE_ATTR = "compression_state"


class CompressionKeys(Enum):
SHARED_INSERTION_COMMAND = "SHARED_INSERTION_COMMAND"
INSERTION_COMMAND = "INSERTION_COMMAND"
SUPPORTED_COMMANDS = (PTSharedFnInsertionCommand, PTInsertionCommand)


def serialize_transformations(transformations_layout: TransformationLayout) -> Dict[str, Any]:
"""
Serializes given transformation layout to a dict.
:param transformations_layout: Given transformation layout.
:return: Serialized representation of given transformation layout as a dict.
"""
transformation_commands = []
for command in transformations_layout.transformations:
serialized_command = serialize_command(command)
@@ -39,28 +41,39 @@ def serialize_transformations(transformations_layout: TransformationLayout) -> D
return {COMPRESSION_STATE_ATTR: transformation_commands}


def load_transformations(transformations_state: Dict[str, Any]) -> TransformationLayout:
def deserialize_transformations(serialized_transformation_layout: Dict[str, Any]) -> TransformationLayout:
"""
Deserializes given serialized transformation layout.
:param serialized_transformation_layout: Given serialized transformation layout.
:return: The deserialized transformation layout.
"""
transformation_layout = TransformationLayout()
for serialized_command in transformations_state[COMPRESSION_STATE_ATTR]:
command = load_command(serialized_command)
for serialized_command in serialized_transformation_layout[COMPRESSION_STATE_ATTR]:
command = deserialize_command(serialized_command)
transformation_layout.register(command)

return transformation_layout


def serialize_command(command: PTTransformationCommand) -> Dict[str, Any]:
if not isinstance(command, (PTSharedFnInsertionCommand, PTInsertionCommand)):
return {}
"""
Serializes given command to a dict.
:param command: Given command.
:return: Serialized representation of given command as a dict.
"""
if not isinstance(command, SUPPORTED_COMMANDS):
raise RuntimeError(f"Command type {command.__class__} is not supported.")

serialized_transformation = dict()
serialized_transformation["type"] = command.__class__.__name__
if isinstance(command, PTSharedFnInsertionCommand):
serialized_transformation["type"] = CompressionKeys.SHARED_INSERTION_COMMAND.value
serialized_transformation["target_points"] = [point.get_state() for point in command.target_points]
serialized_transformation["op_name"] = command.op_name
serialized_transformation["compression_module_type"] = command.compression_module_type.value

elif isinstance(command, PTInsertionCommand):
serialized_transformation["type"] = CompressionKeys.INSERTION_COMMAND.value
serialized_transformation["target_point"] = command.target_point.get_state()

# Check compression module is registered
@@ -78,27 +91,34 @@ def serialize_command(command: PTTransformationCommand) -> Dict[str, Any]:
return serialized_transformation


def load_command(serialized_command: Dict[str, Any]) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]:
def deserialize_command(serialized_command: Dict[str, Any]) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]:
"""
Deserializes given serialized command.
:param serialized_command: Given serialized command.
:return: The deserialized command.
"""
if serialized_command["type"] not in (command_cls.__name__ for command_cls in SUPPORTED_COMMANDS):
raise RuntimeError(f"Command type {serialized_command['type']} is not supported.")

module_cls = COMPRESSION_MODULES.get(serialized_command["compression_module_name"])
fn = module_cls.from_state(serialized_command["fn_state"])
priority = serialized_command["priority"]
if priority in iter(TransformationPriority):
priority = TransformationPriority(priority)

if serialized_command["type"] == CompressionKeys.INSERTION_COMMAND.value:
if serialized_command["type"] == PTInsertionCommand.__name__:
target_point = PTTargetPoint.from_state(serialized_command["target_point"])
return PTInsertionCommand(
point=target_point, fn=fn, priority=priority, hooks_group_name=serialized_command["hooks_group_name"]
)

if serialized_command["type"] == CompressionKeys.SHARED_INSERTION_COMMAND.value:
target_points = [PTTargetPoint.from_state(state) for state in serialized_command["target_points"]]
return PTSharedFnInsertionCommand(
target_points=target_points,
fn=fn,
op_unique_name=serialized_command["op_name"],
compression_module_type=ExtraCompressionModuleType(serialized_command["compression_module_type"]),
priority=priority,
hooks_group_name=serialized_command["hooks_group_name"],
)
raise RuntimeError(f"Command type {serialized_command['type']} is not supported.")
target_points = [PTTargetPoint.from_state(state) for state in serialized_command["target_points"]]
return PTSharedFnInsertionCommand(
target_points=target_points,
fn=fn,
op_unique_name=serialized_command["op_name"],
compression_module_type=ExtraCompressionModuleType(serialized_command["compression_module_type"]),
priority=priority,
hooks_group_name=serialized_command["hooks_group_name"],
)
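To make the renamed serialization API concrete, a minimal round-trip sketch: the function names and the "compression_state" key are taken from this diff, while the TransformationLayout import path and the (empty) layout contents are assumptions.

from nncf.common.graph.transformations.layout import TransformationLayout  # assumed import path
from nncf.torch.graph.transformations.serialization import (
    deserialize_transformations,
    serialize_transformations,
)

layout = TransformationLayout()
# In real use, PTInsertionCommand / PTSharedFnInsertionCommand instances would be registered here.

state = serialize_transformations(layout)      # {"compression_state": [<serialized command dicts>]}
restored = deserialize_transformations(state)  # equivalent TransformationLayout
assert len(restored.transformations) == len(layout.transformations)

As deserialize_command reads them, each serialized command dict carries a "type" (the command class name), a target point state ("target_point" or "target_points"), "fn_state", "compression_module_name", "priority", and "hooks_group_name"; shared insertion commands additionally store "op_name" and "compression_module_type".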
1 change: 0 additions & 1 deletion nncf/torch/model_transformer.py
@@ -61,7 +61,6 @@ def __init__(self, model: NNCFNetwork):
]

def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwork:
# self._model.nncf.record_commands(transformation_layout.transformations)
transformations = transformation_layout.transformations
aggregated_transformations = defaultdict(list)
requires_graph_rebuild = False
Changed file (name not shown in this view)
@@ -1,25 +1,33 @@
{
"TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": {
"input_low": 0.0,
"input_high": 0.9970665574073792
"input_low": [
0.0
],
"input_high": [
0.9970665574073792
]
},
"TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|OUTPUT]": {
"input_low": -3.8243322372436523,
"input_high": 3.794454574584961
"input_low": [
-3.8243322372436523
],
"input_high": [
3.794454574584961
]
},
"TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|INPUT1]": {
"input_low": [
[
[
[
-1
-1.0
]
]
],
[
[
[
-1
-1.0
]
]
]
@@ -28,14 +36,14 @@
[
[
[
1
1.0
]
]
],
[
[
[
1
1.0
]
]
]
@@ -46,7 +54,7 @@
[
[
[
-1
-1.0
]
]
]
@@ -55,26 +63,10 @@
[
[
[
1
1.0
]
]
]
]
},
"TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": {
"input_low": [
0
],
"input_high": [
0.9800970554351807
]
},
"TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0|OUTPUT]": {
"input_low": [
-3.8243322372436523
],
"input_high": [
3.794454574584961
]
}
}
}
Changed file (name not shown in this view)
@@ -1,25 +1,33 @@
{
"TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": {
"input_low": 0.0,
"input_high": 0.9970665574073792
"input_low": [
0.0
],
"input_high": [
0.9970665574073792
]
},
"TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|OUTPUT]": {
"input_low": -3.8243322372436523,
"input_high": 3.794454574584961
"input_low": [
-3.8243322372436523
],
"input_high": [
3.794454574584961
]
},
"TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|INPUT1]": {
"input_low": [
[
[
[
-2
-2.0
]
]
],
[
[
[
-2
-2.0
]
]
]
@@ -28,14 +36,14 @@
[
[
[
2
2.0
]
]
],
[
[
[
2
2.0
]
]
]
@@ -46,7 +54,7 @@
[
[
[
-2
-2.0
]
]
]
@@ -55,26 +63,10 @@
[
[
[
2
2.0
]
]
]
]
},
"TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": {
"input_low": [
0
],
"input_high": [
0.9800970554351807
]
},
"TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0|OUTPUT]": {
"input_low": [
-3.8243322372436523
],
"input_high": [
3.794454574584961
]
}
}
}