Skip to content

Commit

Permalink
Add baseline results for HPO problems (#920)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #920

Added baseline results for MNIST and FASHION_MNIST HPO problems. Also genericized the way we load problem-baseline pairs.

Reviewed By: dme65

Differential Revision: D35729446

fbshipit-source-id: 1aabf16eb517dd471bec87351fbfb4c45d5e08f4
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Apr 20, 2022
1 parent 88fe301 commit e410c3f
Show file tree
Hide file tree
Showing 11 changed files with 138 additions and 75 deletions.
4 changes: 4 additions & 0 deletions ax/benchmark/problems/baseline_results/hpo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"__type": "AggregatedBenchmarkResult",
"name": "HPO_PyTorchCNN_Torchvision::FashionMNIST|SOBOL_BASELINE_1650293655",
"experiments": [],
"optimization_trace": {
"__type": "DataFrame",
"value": "{\"median\":{\"0\":0.1,\"1\":0.1,\"2\":0.1,\"3\":0.1,\"4\":0.1,\"5\":0.1,\"6\":0.1,\"7\":0.1,\"8\":0.1,\"9\":0.1018,\"10\":0.10625,\"11\":0.11685,\"12\":0.1337,\"13\":0.158,\"14\":0.1681,\"15\":0.17055,\"16\":0.18365,\"17\":0.18955,\"18\":0.18955,\"19\":0.19545,\"20\":0.1998,\"21\":0.20345,\"22\":0.21035,\"23\":0.2121,\"24\":0.2121,\"25\":0.2193,\"26\":0.22695,\"27\":0.23825,\"28\":0.24185,\"29\":0.24185,\"30\":0.2735,\"31\":0.281,\"32\":0.281,\"33\":0.28275,\"34\":0.29745,\"35\":0.32205,\"36\":0.32205,\"37\":0.3393,\"38\":0.35655,\"39\":0.35655,\"40\":0.35655,\"41\":0.35655,\"42\":0.37925,\"43\":0.41735,\"44\":0.41735,\"45\":0.44535,\"46\":0.4579,\"47\":0.47825,\"48\":0.47825,\"49\":0.50565,\"50\":0.51565,\"51\":0.52115,\"52\":0.52115,\"53\":0.5257,\"54\":0.5257,\"55\":0.5333,\"56\":0.5421,\"57\":0.5447,\"58\":0.5463,\"59\":0.54675,\"60\":0.5481,\"61\":0.55055,\"62\":0.55055,\"63\":0.55755,\"64\":0.5666,\"65\":0.56785,\"66\":0.5704,\"67\":0.57355,\"68\":0.57355,\"69\":0.5753,\"70\":0.5757,\"71\":0.5757,\"72\":0.5757,\"73\":0.5757,\"74\":0.5757,\"75\":0.5769,\"76\":0.5785,\"77\":0.5785,\"78\":0.5785,\"79\":0.5795,\"80\":0.5795,\"81\":0.5795,\"82\":0.5795,\"83\":0.58055,\"84\":0.5865,\"85\":0.5865,\"86\":0.5878,\"87\":0.5878,\"88\":0.5878,\"89\":0.5889,\"90\":0.59615,\"91\":0.60365,\"92\":0.60365,\"93\":0.60365,\"94\":0.60365,\"95\":0.6097,\"96\":0.6097,\"97\":0.6097,\"98\":0.61005,\"99\":0.61005},\"mean\":{\"0\":0.107524,\"1\":0.113011,\"2\":0.117768,\"3\":0.137795,\"4\":0.140605,\"5\":0.14652,\"6\":0.152092,\"7\":0.153522,\"8\":0.162272,\"9\":0.172777,\"10\":0.189701,\"11\":0.215244,\"12\":0.229722,\"13\":0.232523,\"14\":0.239341,\"15\":0.243123,\"16\":0.260705,\"17\":0.278672,\"18\":0.278854,\"19\":0.292155,\"20\":0.301498,\"21\":0.303506,\"22\":0.30569,\"23\":0.306362,\"24\":0.306362,\"25\":0.31078,\"26\":0.313843,\"27\":0.321665,\"28\":0.325659,\"29\":0.327403,\"30\":0.341891,\"31\":0.345857,\"32\":0.34849,\"33\":0.349448,\"34\":0.35167,\"35\":0.358033,\"36\":0.361194,\"37\":0.364868,\"38\":0.373459,\"39\":0.373459,\"40\":0.373459,\"41\":0.374708,\"42\":0.387815,\"43\":0.395196,\"44\":0.396617,\"45\":0.399121,\"46\":0.402311,\"47\":0.408577,\"48\":0.411614,\"49\":0.418981,\"50\":0.427924,\"51\":0.430873,\"52\":0.431015,\"53\":0.44499,\"54\":0.445328,\"55\":0.458161,\"56\":0.46174,\"57\":0.470881,\"58\":0.471613,\"59\":0.481034,\"60\":0.4883,\"61\":0.489044,\"62\":0.49232,\"63\":0.504968,\"64\":0.517859,\"65\":0.521639,\"66\":0.528044,\"67\":0.530059,\"68\":0.53152,\"69\":0.53816,\"70\":0.538781,\"71\":0.538781,\"72\":0.539454,\"73\":0.540154,\"74\":0.540307,\"75\":0.541903,\"76\":0.543266,\"77\":0.543266,\"78\":0.543266,\"79\":0.547483,\"80\":0.548481,\"81\":0.548481,\"82\":0.549044,\"83\":0.552186,\"84\":0.562209,\"85\":0.563533,\"86\":0.569119,\"87\":0.573102,\"88\":0.573102,\"89\":0.574608,\"90\":0.577508,\"91\":0.579382,\"92\":0.579382,\"93\":0.579382,\"94\":0.579382,\"95\":0.582356,\"96\":0.587131,\"97\":0.590691,\"98\":0.592098,\"99\":0.592797},\"sem\":{\"0\":0.0060621579,\"1\":0.0064993247,\"2\":0.0070555215,\"3\":0.0109174311,\"4\":0.0110927262,\"5\":0.0110822591,\"6\":0.0114919935,\"7\":0.0115402539,\"8\":0.0124269807,\"9\":0.0132780931,\"10\":0.0147192499,\"11\":0.0176739101,\"12\":0.0189090479,\"13\":0.0188187504,\"14\":0.0189664656,\"15\":0.0189059507,\"16\":0.0195886675,\"17\":0.0202922818,\"18\":0.0202769053,\"19\":0.020396165,\"20\":0.0205667946,\"21\":0.0206015783,\"22\":0.0205551164,\"23\":0.0205108025,\"24\":0.0205108025,\"25\":0.0204872303,\"26\":0.0204454726,\"27\":0.0204449775,\"28\":0.0203396224,\"29\":0.0201999296,\"30\":0.0201698607,\"31\":0.0201741783,\"32\":0.0203076604,\"33\":0.0202529938,\"34\":0.0202079418,\"35\":0.0202848889,\"36\":0.0206296387,\"37\":0.0206217256,\"38\":0.0206126462,\"39\":0.0206126462,\"40\":0.0206126462,\"41\":0.0204827044,\"42\":0.020576227,\"43\":0.0204912784,\"44\":0.0206838043,\"45\":0.0206391426,\"46\":0.0207792183,\"47\":0.0207153645,\"48\":0.020675578,\"49\":0.0209120355,\"50\":0.0209042021,\"51\":0.0208892251,\"52\":0.0208994643,\"53\":0.0206891109,\"54\":0.0206443485,\"55\":0.0201230466,\"56\":0.0200788944,\"57\":0.0197713259,\"58\":0.0197984428,\"59\":0.0192940025,\"60\":0.0191556825,\"61\":0.0191753387,\"62\":0.018976051,\"63\":0.0189211704,\"64\":0.0186694195,\"65\":0.0184688394,\"66\":0.0179560299,\"67\":0.0180849005,\"68\":0.0177963071,\"69\":0.0175678391,\"70\":0.0175799973,\"71\":0.0175799973,\"72\":0.0174648638,\"73\":0.0173063717,\"74\":0.0172676992,\"75\":0.0173000037,\"76\":0.0173324947,\"77\":0.0173324947,\"78\":0.0173324947,\"79\":0.017706858,\"80\":0.017685449,\"81\":0.017685449,\"82\":0.0176707759,\"83\":0.0178870369,\"84\":0.0176152785,\"85\":0.0177029989,\"86\":0.0172398605,\"87\":0.0168005236,\"88\":0.0168005236,\"89\":0.0168487558,\"90\":0.0165488455,\"91\":0.0166139793,\"92\":0.0166139793,\"93\":0.0166139793,\"94\":0.0166139793,\"95\":0.0167067797,\"96\":0.0161106918,\"97\":0.0157508082,\"98\":0.0157995977,\"99\":0.0156524714}}"
},
"fit_time": [
2.228915958404541,
0.04985408353937597
],
"gen_time": [
0.19300081288383808,
0.010364870053888276
]
}
17 changes: 17 additions & 0 deletions ax/benchmark/problems/baseline_results/hpo/torchvision/mnist.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"__type": "AggregatedBenchmarkResult",
"name": "HPO_PyTorchCNN_Torchvision::MNIST|SOBOL_BASELINE_1650250692",
"experiments": [],
"optimization_trace": {
"__type": "DataFrame",
"value": "{\"median\":{\"0\":0.10925,\"1\":0.1135,\"2\":0.1135,\"3\":0.1135,\"4\":0.1135,\"5\":0.1135,\"6\":0.1135,\"7\":0.1135,\"8\":0.1135,\"9\":0.1135,\"10\":0.1135,\"11\":0.11885,\"12\":0.13135,\"13\":0.13135,\"14\":0.1353,\"15\":0.18545,\"16\":0.18645,\"17\":0.19165,\"18\":0.19745,\"19\":0.2023,\"20\":0.2047,\"21\":0.21035,\"22\":0.21695,\"23\":0.2219,\"24\":0.24725,\"25\":0.26465,\"26\":0.26575,\"27\":0.26575,\"28\":0.27125,\"29\":0.2804,\"30\":0.2804,\"31\":0.31635,\"32\":0.33245,\"33\":0.35745,\"34\":0.35745,\"35\":0.4591,\"36\":0.46855,\"37\":0.4965,\"38\":0.5425,\"39\":0.5425,\"40\":0.5727,\"41\":0.58285,\"42\":0.58285,\"43\":0.59055,\"44\":0.59055,\"45\":0.59055,\"46\":0.6085,\"47\":0.623,\"48\":0.6321,\"49\":0.6321,\"50\":0.65285,\"51\":0.6695,\"52\":0.68945,\"53\":0.6974,\"54\":0.7064,\"55\":0.7131,\"56\":0.7131,\"57\":0.7131,\"58\":0.7131,\"59\":0.7174,\"60\":0.73595,\"61\":0.73595,\"62\":0.75415,\"63\":0.7561,\"64\":0.7561,\"65\":0.7582,\"66\":0.7582,\"67\":0.7582,\"68\":0.7582,\"69\":0.7582,\"70\":0.7601,\"71\":0.7601,\"72\":0.7601,\"73\":0.7675,\"74\":0.7675,\"75\":0.7783,\"76\":0.7783,\"77\":0.7867,\"78\":0.78805,\"79\":0.78805,\"80\":0.78805,\"81\":0.78805,\"82\":0.78805,\"83\":0.78805,\"84\":0.78805,\"85\":0.78805,\"86\":0.78805,\"87\":0.79895,\"88\":0.79895,\"89\":0.79895,\"90\":0.79895,\"91\":0.79895,\"92\":0.79895,\"93\":0.79895,\"94\":0.80585,\"95\":0.80585,\"96\":0.8096,\"97\":0.8096,\"98\":0.8096,\"99\":0.8096},\"mean\":{\"0\":0.111997,\"1\":0.13768,\"2\":0.156231,\"3\":0.187791,\"4\":0.200755,\"5\":0.206198,\"6\":0.208183,\"7\":0.219609,\"8\":0.223913,\"9\":0.227184,\"10\":0.229302,\"11\":0.258208,\"12\":0.276752,\"13\":0.283532,\"14\":0.287326,\"15\":0.307243,\"16\":0.309055,\"17\":0.315535,\"18\":0.326869,\"19\":0.335569,\"20\":0.336748,\"21\":0.349527,\"22\":0.351527,\"23\":0.364343,\"24\":0.377646,\"25\":0.395137,\"26\":0.398933,\"27\":0.399405,\"28\":0.404648,\"29\":0.408885,\"30\":0.409827,\"31\":0.436669,\"32\":0.456107,\"33\":0.468825,\"34\":0.469973,\"35\":0.490575,\"36\":0.498297,\"37\":0.508823,\"38\":0.515039,\"39\":0.519119,\"40\":0.524195,\"41\":0.535044,\"42\":0.536328,\"43\":0.548188,\"44\":0.548954,\"45\":0.548954,\"46\":0.558487,\"47\":0.562538,\"48\":0.572777,\"49\":0.576176,\"50\":0.582628,\"51\":0.591259,\"52\":0.60171,\"53\":0.607924,\"54\":0.61267,\"55\":0.619509,\"56\":0.619509,\"57\":0.619673,\"58\":0.624574,\"59\":0.629521,\"60\":0.638889,\"61\":0.641469,\"62\":0.647532,\"63\":0.657666,\"64\":0.658142,\"65\":0.659348,\"66\":0.664494,\"67\":0.667649,\"68\":0.667649,\"69\":0.667649,\"70\":0.672243,\"71\":0.674973,\"72\":0.674973,\"73\":0.68239,\"74\":0.68239,\"75\":0.686303,\"76\":0.68852,\"77\":0.695545,\"78\":0.697068,\"79\":0.6972,\"80\":0.697442,\"81\":0.697442,\"82\":0.707465,\"83\":0.708479,\"84\":0.709618,\"85\":0.710096,\"86\":0.710096,\"87\":0.712416,\"88\":0.712416,\"89\":0.716937,\"90\":0.720653,\"91\":0.722019,\"92\":0.723241,\"93\":0.723241,\"94\":0.731965,\"95\":0.731965,\"96\":0.735098,\"97\":0.735098,\"98\":0.735947,\"99\":0.736529},\"sem\":{\"0\":0.0028764251,\"1\":0.0123145319,\"2\":0.0153962952,\"3\":0.0204801864,\"4\":0.0216928831,\"5\":0.0217169482,\"6\":0.0216718036,\"7\":0.0227463455,\"8\":0.0227233591,\"9\":0.0227110735,\"10\":0.0226556015,\"11\":0.0248483052,\"12\":0.0261523134,\"13\":0.0266601738,\"14\":0.0265300315,\"15\":0.0269346561,\"16\":0.026832998,\"17\":0.0271819441,\"18\":0.0282261372,\"19\":0.028410789,\"20\":0.0283382819,\"21\":0.0284132713,\"22\":0.0282976369,\"23\":0.0286014939,\"24\":0.0284116127,\"25\":0.028921068,\"26\":0.0290950723,\"27\":0.0290600218,\"28\":0.0289165684,\"29\":0.0286995586,\"30\":0.0286169663,\"31\":0.029172965,\"32\":0.0296112252,\"33\":0.0299188124,\"34\":0.0298295786,\"35\":0.0298145944,\"36\":0.0297088415,\"37\":0.0295672579,\"38\":0.0294109093,\"39\":0.0290612088,\"40\":0.0291565769,\"41\":0.0293327956,\"42\":0.029212948,\"43\":0.0290442388,\"44\":0.0290803745,\"45\":0.0290803745,\"46\":0.0288501951,\"47\":0.028593242,\"48\":0.028264535,\"49\":0.0280836676,\"50\":0.0277371714,\"51\":0.027542836,\"52\":0.0276169248,\"53\":0.0273453149,\"54\":0.027110287,\"55\":0.0272742714,\"56\":0.0272742714,\"57\":0.0272440141,\"58\":0.02693928,\"59\":0.0267804031,\"60\":0.0265302448,\"61\":0.026400596,\"62\":0.0265576523,\"63\":0.0260572178,\"64\":0.0260062951,\"65\":0.0260478364,\"66\":0.0255961556,\"67\":0.0252231306,\"68\":0.0252231306,\"69\":0.0252231306,\"70\":0.0250182754,\"71\":0.0248516718,\"72\":0.0248516718,\"73\":0.0244339143,\"74\":0.0244339143,\"75\":0.0242392186,\"76\":0.0240540369,\"77\":0.0233609561,\"78\":0.0234706356,\"79\":0.0234834802,\"80\":0.023490753,\"81\":0.023490753,\"82\":0.0223375596,\"83\":0.0222468011,\"84\":0.0223271778,\"85\":0.0223506596,\"86\":0.0223506596,\"87\":0.0224750452,\"88\":0.0224750452,\"89\":0.021996052,\"90\":0.0214374398,\"91\":0.0212025531,\"92\":0.0210979595,\"93\":0.0210979595,\"94\":0.0205940329,\"95\":0.0205940329,\"96\":0.0206296771,\"97\":0.0206296771,\"98\":0.0206863537,\"99\":0.0206349264}}"
},
"fit_time": [
2.592710185050964,
0.06607956279295531
],
"gen_time": [
0.2792521435397066,
0.020075580536743538
]
}
2 changes: 1 addition & 1 deletion ax/benchmark/problems/hpo/pytorch_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def from_datasets(
runner = PyTorchCNNRunner(name=name, train_set=train_set, test_set=test_set)

return cls(
name=name,
name=f"HPO_PyTorchCNN_{name}",
optimal_value=optimal_value,
search_space=search_space,
optimization_config=optimization_config,
Expand Down
10 changes: 7 additions & 3 deletions ax/benchmark/problems/hpo/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import Any, Dict

from ax.benchmark.problems.hpo.pytorch_cnn import (
Expand All @@ -11,7 +12,7 @@
)
from ax.core.runner import Runner
from ax.exceptions.core import UserInputError
from ax.utils.common.typeutils import checked_cast
from ax.utils.common.typeutils import not_none, checked_cast

try: # We don't require TorchVision by default.
from torchvision import transforms, datasets
Expand Down Expand Up @@ -53,7 +54,7 @@ def from_dataset_name(cls, name: str) -> "PyTorchCNNTorchvisionBenchmarkProblem"
)

return cls(
name=problem.name,
name=f"HPO_PyTorchCNN_Torchvision::{name}",
search_space=problem.search_space,
optimization_config=problem.optimization_config,
runner=runner,
Expand All @@ -71,7 +72,10 @@ class PyTorchCNNTorchvisionRunner(PyTorchCNNRunner):
def serialize_init_args(cls, runner: Runner) -> Dict[str, Any]:
pytorch_cnn_runner = checked_cast(PyTorchCNNRunner, runner)

return {"name": pytorch_cnn_runner.name}
pattern = re.compile("(?<=::).*") # Extract the dataset name
dataset_name = not_none(pattern.search(pytorch_cnn_runner.name)).group()

return {"name": dataset_name}

@classmethod
def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
73 changes: 73 additions & 0 deletions ax/benchmark/problems/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from dataclasses import dataclass
from typing import Any, Dict, Callable, Tuple

from ax.benchmark.benchmark_problem import (
MultiObjectiveBenchmarkProblem,
BenchmarkProblem,
SingleObjectiveBenchmarkProblem,
)
from ax.benchmark.benchmark_result import AggregatedBenchmarkResult
from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionBenchmarkProblem
from ax.storage.json_store.decoder import object_from_json
from botorch.test_functions.multi_objective import BraninCurrin
from botorch.test_functions.synthetic import Branin, Ackley


@dataclass
class BenchmarkProblemRegistryEntry:
factory_fn: Callable[..., BenchmarkProblem]
factory_kwargs: Dict[str, Any]
baseline_results_path: str


BENCHMARK_PROBLEM_REGISTRY = {
"ackley": BenchmarkProblemRegistryEntry(
factory_fn=SingleObjectiveBenchmarkProblem.from_botorch_synthetic,
factory_kwargs={"test_problem": Ackley()},
baseline_results_path="baseline_results/synthetic/ackley.json",
),
"branin": BenchmarkProblemRegistryEntry(
factory_fn=SingleObjectiveBenchmarkProblem.from_botorch_synthetic,
factory_kwargs={"test_problem": Branin()},
baseline_results_path="baseline_results/synthetic/branin.json",
),
"branin_currin": BenchmarkProblemRegistryEntry(
factory_fn=MultiObjectiveBenchmarkProblem.from_botorch_multi_objective,
factory_kwargs={"test_problem": BraninCurrin()},
baseline_results_path="baseline_results/synthetic/branin_currin.json",
),
"hpo_pytorch_cnn_MNIST": BenchmarkProblemRegistryEntry(
factory_fn=PyTorchCNNTorchvisionBenchmarkProblem.from_dataset_name,
factory_kwargs={"name": "MNIST"},
baseline_results_path="baseline_results/hpo/torchvision/mnist.json",
),
"hpo_pytorch_cnn_FashionMNIST": BenchmarkProblemRegistryEntry(
factory_fn=PyTorchCNNTorchvisionBenchmarkProblem.from_dataset_name,
factory_kwargs={"name": "FashionMNIST"},
baseline_results_path="baseline_results/hpo/torchvision/fashion_mnist.json",
),
}


def get_problem_and_baseline(
problem_name: str,
) -> Tuple[BenchmarkProblem, AggregatedBenchmarkResult]:
entry = BENCHMARK_PROBLEM_REGISTRY[problem_name]

problem = entry.factory_fn(**entry.factory_kwargs)

current_dir = os.path.dirname(__file__)
file_path = os.path.join(current_dir, entry.baseline_results_path)

with open(file=file_path) as file:
loaded = json.loads(file.read())
baseline_result = object_from_json(loaded)

return (problem, baseline_result)
55 changes: 0 additions & 55 deletions ax/benchmark/problems/synthetic.py

This file was deleted.

21 changes: 9 additions & 12 deletions ax/benchmark/tests/test_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
# LICENSE file in the root directory of this source tree.

from ax.benchmark.benchmark_result import AggregatedBenchmarkResult
from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionBenchmarkProblem
from ax.benchmark.problems.synthetic import (
get_problem_and_baseline_from_botorch,
_REGISTRY,
from ax.benchmark.problems.registry import (
get_problem_and_baseline,
BENCHMARK_PROBLEM_REGISTRY,
)
from ax.utils.common.testutils import TestCase

Expand All @@ -16,13 +15,11 @@ class TestProblems(TestCase):
def test_load_baselines(self):

# Make sure the json parsing suceeded
for name in _REGISTRY.keys():
_problem, baseline = get_problem_and_baseline_from_botorch(
problem_name=name
)
for name in BENCHMARK_PROBLEM_REGISTRY.keys():
if "MNIST" in name:
continue # Skip these as they cause the test to take a long time

self.assertTrue(isinstance(baseline, AggregatedBenchmarkResult))
problem, baseline = get_problem_and_baseline(problem_name=name)

def test_pytorch_cnn(self):
# Just check data loading and construction succeeds
PyTorchCNNTorchvisionBenchmarkProblem.from_dataset_name(name="MNIST")
self.assertTrue(isinstance(baseline, AggregatedBenchmarkResult))
self.assertIn(problem.name, baseline.name)
4 changes: 3 additions & 1 deletion ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import Any, Dict, Type

from ax.benchmark.problems.hpo.torchvision import (
Expand Down Expand Up @@ -47,6 +48,7 @@
from ax.storage.botorch_modular_registry import CLASS_TO_REGISTRY
from ax.storage.transform_registry import TRANSFORM_REGISTRY
from ax.utils.common.serialization import serialize_init_args
from ax.utils.common.typeutils import not_none


def experiment_to_dict(experiment: Experiment) -> Dict[str, Any]:
Expand Down Expand Up @@ -540,5 +542,5 @@ def pytorch_cnn_torchvision_benchmark_problem_to_dict(
# unit tests for this in benchmark suite
return { # pragma: no cover
"__type": problem.__class__.__name__,
"name": problem.name,
"name": not_none(re.compile("(?<=::).*").search(problem.name)).group(),
}
Loading

0 comments on commit e410c3f

Please sign in to comment.