Skip to content

Commit

Permalink
upd dump_statistics_to_dir signature
Browse files Browse the repository at this point in the history
  • Loading branch information
kshpv committed Dec 13, 2024
1 parent 1fcaf1d commit c79fe3b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions nncf/common/tensor_statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def dump_statistics(self, dir_path: Path) -> None:
:param dir_path: The path of the directory where the statistics will be saved.
"""
data_to_dump = self._prepare_statistics()
metadata = {"backend": self.BACKEND.value, "subset_size": self.stat_subset_size}
dump_statistics_to_dir(data_to_dump, dir_path, metadata)
additional_metadata = {"subset_size": self.stat_subset_size}
dump_statistics_to_dir(data_to_dump, dir_path, self.BACKEND, additional_metadata)
nncf_logger.info(f"Statistics were successfully saved to a directory {dir_path.absolute()}")

def _prepare_statistics(self) -> Dict[str, Any]:
Expand Down
4 changes: 3 additions & 1 deletion nncf/common/tensor_statistics/statistics_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def load_statistics_from_dir(dir_path: Path, backend: BackendType) -> Dict[str,
def dump_statistics_to_dir(
statistics: Dict[str, Dict[str, Tensor]],
dir_path: Path,
backend: BackendType,
additional_metadata: Dict[str, Any],
) -> None:
"""
Expand All @@ -121,10 +122,11 @@ def dump_statistics_to_dir(
:param statistics: A dictionary with statistic names as keys and the statistic data as values.
:param dir_path: The path to the directory where the statistics will be dumped.
:param backend: Backend type to save in metadata.
:param additional_metadata: A dictionary containing any additional metadata to be saved with the mapping.
"""
dir_path.mkdir(parents=True, exist_ok=True)
metadata: Dict[str, Any] = {"mapping": {}}
metadata: Dict[str, Any] = {"mapping": {}, "backend": backend.value}
unique_map: Dict[str, List[str]] = defaultdict(list)
for original_name, statistics_value in statistics.items():
sanitized_name = sanitize_filename(original_name)
Expand Down
4 changes: 2 additions & 2 deletions tests/cross_fw/test_templates/test_statistics_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def test_load_no_statistics_file(self, tmp_path):
def test_dump_and_load_statistics(self, tmp_path):
backend = self._get_backend()
statistics = self._get_backend_statistics()
additional_metadata = {"model": "facebook/opt-125m", "compression": "8-bit", "backend": backend.value}
additional_metadata = {"model": "facebook/opt-125m", "compression": "8-bit"}

dump_statistics_to_dir(statistics, tmp_path, additional_metadata)
dump_statistics_to_dir(statistics, tmp_path, backend, additional_metadata)

assert len(list(Path(tmp_path).iterdir())) > 0, "No files created during dumping"

Expand Down

0 comments on commit c79fe3b

Please sign in to comment.