diff --git a/nni/experiment/config/experiment_config.py b/nni/experiment/config/experiment_config.py index c82aba381b..74b1093b14 100644 --- a/nni/experiment/config/experiment_config.py +++ b/nni/experiment/config/experiment_config.py @@ -8,7 +8,9 @@ __all__ = ['ExperimentConfig'] from dataclasses import dataclass +import json import logging +from pathlib import Path from typing import Any, List, Optional, Union import yaml @@ -113,6 +115,16 @@ def _canonicalize(self, _parents): super()._canonicalize([self]) + if self.search_space_file is not None: + yaml_error = None + try: + self.search_space = _load_search_space_file(self.search_space_file) + except Exception as e: + yaml_error = repr(e) + if yaml_error is not None: # raise it outside except block to make stack trace clear + msg = f'ExperimentConfig: Failed to load search space file "{self.search_space_file}": {yaml_error}' + raise ValueError(msg) + if self.nni_manager_ip is None: # show a warning if user does not set nni_manager_ip. 
def _load_search_space_file(search_space_path):
    """
    Load a search space definition from a JSON or YAML file.

    The content is parsed as JSON first; only when it is not valid JSON is it
    handed to ``yaml.safe_load``.

    FIXME:
    The two-step parsing is needed because PyYAML 6.0 does not support YAML 1.2,
    which means it is not fully compatible with JSON.

    Parameters
    ----------
    search_space_path
        Location of the search space file. Anything accepted by ``pathlib.Path``
        (for example ``str`` or ``Path``).

    Returns
    -------
    The object produced by ``json.loads``, or by ``yaml.safe_load`` when the
    content is not valid JSON.

    Raises
    ------
    OSError
        If the file cannot be read.
    UnicodeDecodeError
        If the file is not UTF-8 encoded.

    Any exception raised by ``yaml.safe_load`` is propagated unchanged, so the
    caller can report it together with the file name.
    """
    # 'utf-8-sig' reads plain UTF-8 exactly like 'utf8', and additionally drops a
    # leading byte order mark. json.loads() rejects a BOM, so without this a
    # BOM-prefixed JSON file would always fall through to the YAML parser.
    content = Path(search_space_path).read_text(encoding='utf-8-sig')
    try:
        return json.loads(content)
    except ValueError:
        # json.JSONDecodeError is a subclass of ValueError.
        # Anything else is a real error and must not be hidden by the fallback.
        return yaml.safe_load(content)
"""
Check that ``ExperimentConfig`` loads ``search_space_file`` in every supported
file format and yields the same search space for each of them.
"""

import json
from pathlib import Path

import yaml

from nni.experiment.config import ExperimentConfig, AlgorithmConfig, LocalConfig

## template ##

# Shared experiment config; the test only swaps ``search_space_file`` on it.
# NOTE(review): the tuner name 'randomm' looks like a typo for 'random'.
# It is kept as is because this file does not show whether the name matters here.
config = ExperimentConfig(
    search_space_file = '',
    trial_command = 'echo hello',
    trial_concurrency = 1,
    tuner = AlgorithmConfig(name='randomm'),
    training_service = LocalConfig()
)

# The search space every fixture file is expected to describe.
space_correct = {
    'pool_type': {
        '_type': 'choice',
        '_value': ['max', 'min', 'avg']
    },
    '学习率': {
        '_type': 'loguniform',
        '_value': [1e-7, 0.1]
    }
}

# FIXME
# PyYAML 6.0 (YAML 1.1) does not support tab and scientific notation
# JSON does not support comment and extra comma
# So some combinations will fail to load
#
# Each entry is (fixture file name under ``assets/``, description used in the failure message).
# The commented-out entries are the combinations that are known to fail for the reasons above.
formats = [
    ('ss_tab.json', 'JSON (tabs + scientific notation)'),
    ('ss_comma.json', 'JSON with extra comma'),
    #('ss_tab_comma.json', 'JSON (tabs + scientific notation) with extra comma'),
    ('ss.yaml', 'YAML'),
    #('ss_yaml12.yaml', 'YAML 1.2 with scientific notation'),
]

def test_search_space():
    """
    Load every fixture listed in ``formats`` through ``config.json()`` and
    compare the resulting ``searchSpace`` with ``space_correct``.

    On failure the description of the offending format is printed before the
    original exception is re-raised.
    """
    for space_file, description in formats:
        try:
            config.search_space_file = Path(__file__).parent / 'assets' / space_file
            space = config.json()['searchSpace']
            assert space == space_correct
        except Exception:
            print('Failed to load search space format: ' + description)
            # bare raise keeps the original exception and traceback untouched
            raise

if __name__ == '__main__':
    test_search_space()