Skip to content

Commit

Permalink
Corrected the docstring of TuningConfig (#1639)
Browse files Browse the repository at this point in the history
Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: Kaihui-intel <[email protected]>
  • Loading branch information
yiliu30 and Kaihui-intel authored Feb 28, 2024
1 parent 2b86e50 commit 853dc71
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions neural_compressor/common/base_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@


import copy
import inspect
import uuid
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Sized, Tuple, Union

from neural_compressor.common import Logger
from neural_compressor.common.base_config import BaseConfig, ComposableConfig
from neural_compressor.common.base_config import BaseConfig
from neural_compressor.common.utils import TuningLogger

logger = Logger().get_logger()
Expand Down Expand Up @@ -227,19 +226,11 @@ def __iter__(self) -> Generator[BaseConfig, Any, None]:


class TuningConfig:
"""Base Class for Tuning Criterion.
Args:
config_set: quantization configs. Default value is empty.
A single config or a list of configs. More details can
be found in the `from_fwk_configs`of `ConfigSet` class.
max_trials: Max tuning times. Default value is 100. Combine with timeout field to decide when to exit.
tolerable_loss: This float indicates how much metric loss we can accept. \
The metric loss is relative, it can be both positive and negative. Default is 0.01.
"""Config for auto tuning pipeline.
Examples:
# TODO: to refine it
from neural_compressor import TuningConfig
from neural_compressor.torch.quantization import TuningConfig
tune_config = TuningConfig(
config_set=[config1, config2, ...],
max_trials=3,
Expand All @@ -264,13 +255,25 @@ class TuningConfig:
"""

def __init__(
self, config_set=None, max_trials=100, sampler: Sampler = default_sampler, tolerable_loss=0.01
) -> None:
"""Init a TuneCriterion object."""
self,
config_set: Union[BaseConfig, List[BaseConfig]] = None,
sampler: Sampler = default_sampler,
tolerable_loss=0.01,
max_trials=100,
):
"""Initial a TuningConfig.
Args:
config_set: A single config or a list of configs. Defaults to None.
sampler: tuning sampler that decide the trials order. Defaults to default_sampler.
tolerable_loss: This float indicates how much metric loss we can accept.
The metric loss is relative, it can be both positive and negative. Default is 0.01.
max_trials: Max tuning times. Combine with `tolerable_loss` field to decide when to stop. Default is 100.
"""
self.config_set = config_set
self.max_trials = max_trials
self.sampler = sampler
self.tolerable_loss = tolerable_loss
self.max_trials = max_trials


class _TrialRecord:
Expand Down

0 comments on commit 853dc71

Please sign in to comment.