diff --git a/neural_compressor/common/base_tuning.py b/neural_compressor/common/base_tuning.py index ca8e9370623..352ac374ab4 100644 --- a/neural_compressor/common/base_tuning.py +++ b/neural_compressor/common/base_tuning.py @@ -14,12 +14,11 @@ import copy -import inspect import uuid from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Sized, Tuple, Union from neural_compressor.common import Logger -from neural_compressor.common.base_config import BaseConfig, ComposableConfig +from neural_compressor.common.base_config import BaseConfig from neural_compressor.common.utils import TuningLogger logger = Logger().get_logger() @@ -227,19 +226,11 @@ def __iter__(self) -> Generator[BaseConfig, Any, None]: class TuningConfig: - """Base Class for Tuning Criterion. - - Args: - config_set: quantization configs. Default value is empty. - A single config or a list of configs. More details can - be found in the `from_fwk_configs`of `ConfigSet` class. - max_trials: Max tuning times. Default value is 100. Combine with timeout field to decide when to exit. - tolerable_loss: This float indicates how much metric loss we can accept. \ - The metric loss is relative, it can be both positive and negative. Default is 0.01. + """Config for auto tuning pipeline. Examples: # TODO: to refine it - from neural_compressor import TuningConfig + from neural_compressor.torch.quantization import TuningConfig tune_config = TuningConfig( config_set=[config1, config2, ...], max_trials=3, @@ -264,13 +255,25 @@ class TuningConfig: """ def __init__( - self, config_set=None, max_trials=100, sampler: Sampler = default_sampler, tolerable_loss=0.01 - ) -> None: - """Init a TuneCriterion object.""" + self, + config_set: Union[BaseConfig, List[BaseConfig]] = None, + sampler: Sampler = default_sampler, + tolerable_loss=0.01, + max_trials=100, + ): + """Initial a TuningConfig. + + Args: + config_set: A single config or a list of configs. Defaults to None. + sampler: tuning sampler that decide the trials order. Defaults to default_sampler. + tolerable_loss: This float indicates how much metric loss we can accept. + The metric loss is relative, it can be both positive and negative. Default is 0.01. + max_trials: Max tuning times. Combine with `tolerable_loss` field to decide when to stop. Default is 100. + """ self.config_set = config_set - self.max_trials = max_trials self.sampler = sampler self.tolerable_loss = tolerable_loss + self.max_trials = max_trials class _TrialRecord: