-
Notifications
You must be signed in to change notification settings - Fork 31
/
test_config_templates.py
87 lines (73 loc) · 2.66 KB
/
test_config_templates.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
from pathlib import Path
from typing import Iterator

import pytest
import requests
import yaml
from gretel_client import ClientConfig, configure_session
from gretel_client.projects import Project, tmp_project
from gretel_client.projects.models import read_model_config
# Connection settings for the live Gretel Cloud API. Both are mandatory:
# importing this module raises immediately when either is missing or empty.
_api_key = os.environ.get("GRETEL_API_KEY")
_cloud_url = os.environ.get("GRETEL_CLOUD_URL")
if not _api_key:
    raise RuntimeError("API key env not set")
if not _cloud_url:
    raise RuntimeError("Gretel cloud url env not set")

# Configure the gretel_client session with the same endpoint and key that the
# raw HTTP requests below use.
configure_session(ClientConfig(endpoint=_cloud_url, api_key=_api_key))
@pytest.fixture(scope="session")
def project() -> Iterator[Project]:
    """Provide one temporary Gretel project shared by the whole test session.

    The project is yielded from inside the ``tmp_project()`` context manager,
    so whatever teardown that context manager performs (presumably deleting
    the project) runs after the last test in the session.
    """
    with tmp_project() as tmp:
        yield tmp
_configs = list((Path(__file__).parent / "config_templates").glob("**/*.yml"))
@pytest.mark.parametrize(
    "_config_file",
    [_config for _config in _configs if _config.parent.name != "tuner"],
    # Pass absolute paths so the test does not depend on pytest's working
    # directory, but keep test IDs short and readable (relative, POSIX-style).
    ids=lambda _config: _config.relative_to(Path(__file__).parent).as_posix(),
)
def test_configs(_config_file: Path, project: Project):
    """Validate a non-tuner config template with a dry-run model creation.

    The template's YAML is posted as-is to the project's ``models`` endpoint
    with ``dry_run=yes``; the API must answer HTTP 200.
    """
    _config_dict = yaml.safe_load(_config_file.read_text(encoding="utf-8"))
    resp = requests.post(
        f"{_cloud_url}/projects/{project.name}/models",
        json=_config_dict,
        params={"dry_run": "yes"},
        headers={"Authorization": _api_key},
        # requests applies no timeout by default; bound the call so a stuck
        # API request fails the test instead of hanging the whole run.
        timeout=120,
    )
    assert resp.status_code == 200, (
        f"Error for {_config_file} at {_cloud_url}, got response: {resp.text}"
    )
@pytest.mark.parametrize(
    "_config_file",
    [_config for _config in _configs if _config.parent.name == "tuner"],
    # Pass absolute paths so the test does not depend on pytest's working
    # directory, but keep test IDs short and readable (relative, POSIX-style).
    ids=lambda _config: _config.relative_to(Path(__file__).parent).as_posix(),
)
def test_tuner_configs(_config_file: Path, project: Project):
    """Validate a tuner config template with a dry-run model creation.

    A tuner template names a ``base_config`` and, per config section, a set
    of parameter options. The first option of every parameter is written into
    the base model config, and the resulting config is dry-run against the
    project's ``models`` endpoint, which must answer HTTP 200.
    """
    tuner_config_dict = yaml.safe_load(_config_file.read_text(encoding="utf-8"))
    # The metric is not sent to the API, so it is discarded; ``pop`` without
    # a default also asserts that the template declares one.
    tuner_config_dict.pop("metric")
    base_config = tuner_config_dict.pop("base_config")
    config = read_model_config(base_config)
    # First value of the first model entry: presumably the
    # ``{model_type: params}`` mapping of the base config -- TODO confirm.
    model_config = next(iter(config["models"][0].values()))

    # Update the model config with the tuner params. Every remaining key of
    # the tuner template is a section that must already exist in the model.
    for section, section_params in tuner_config_dict.items():
        assert section in model_config, (
            f"Tuner section {section!r} not found in base config {base_config!r}"
        )
        for name, options in section_params.items():
            # Tuner param options are always list-like: take the first entry
            # of the first option kind for this parameter.
            value = next(iter(options.values()))[0]
            # Overwrites an existing parameter or adds a new one.
            model_config[section][name] = value

    # Execute a dry run via the API.
    resp = requests.post(
        f"{_cloud_url}/projects/{project.name}/models",
        json=config,
        params={"dry_run": "yes"},
        headers={"Authorization": _api_key},
        # requests applies no timeout by default; bound the call so a stuck
        # API request fails the test instead of hanging the whole run.
        timeout=120,
    )
    assert resp.status_code == 200, (
        f"Error for {_config_file} at {_cloud_url}, got response: {resp.text}"
    )