Skip to content

Commit

Permalink
Add function for embedding problems in higher dimmensional space (#925)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #925

Add a function to add n dummy parameters to a benchmark problem's search space, effectively embedding the problem in a higher dimension. This will be useful of for benchmarking our HDBO methods (SAASBO, etc).

Reviewed By: dme65

Differential Revision: D35726713

fbshipit-source-id: 42774f8dbebb5294814075e38c3b3e774763584d
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Apr 20, 2022
1 parent e410c3f commit 751fa64
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 2 deletions.
54 changes: 54 additions & 0 deletions ax/benchmark/methods/modular_botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from ax.service.scheduler import SchedulerOptions
from ax.utils.common.constants import Keys
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
from botorch.acquisition.multi_objective.monte_carlo import (
qNoisyExpectedHypervolumeImprovement,
)
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.gp_regression import FixedNoiseGP


Expand Down Expand Up @@ -72,3 +76,53 @@ def get_sobol_botorch_modular_default():
generation_strategy=generation_strategy,
scheduler_options=scheduler_options,
)


def get_sobol_botorch_modular_saas_fully_bayesian_gp_qnei():
generation_strategy = GenerationStrategy(
name="SOBOL+BOTORCH_MODULAR::default",
steps=[
GenerationStep(model=Models.SOBOL, num_trials=5, min_trials_observed=3),
GenerationStep(
model=Models.BOTORCH_MODULAR,
num_trials=-1,
model_kwargs={
"surrogate": Surrogate(SaasFullyBayesianSingleTaskGP),
"botorch_acqf_class": qNoisyExpectedImprovement,
},
),
],
)

scheduler_options = SchedulerOptions(total_trials=30)

return BenchmarkMethod(
name="SOBOL+BOTORCH_MODULAR::SaasFullyBayesianSingleTaskGP_qNoisyExpectedImprovement", # noqa
generation_strategy=generation_strategy,
scheduler_options=scheduler_options,
)


def get_sobol_botorch_modular_saas_fully_bayesian_gp_qnehvi():
generation_strategy = GenerationStrategy(
name="SOBOL+BOTORCH_MODULAR::default",
steps=[
GenerationStep(model=Models.SOBOL, num_trials=5, min_trials_observed=3),
GenerationStep(
model=Models.BOTORCH_MODULAR,
num_trials=-1,
model_kwargs={
"surrogate": Surrogate(SaasFullyBayesianSingleTaskGP),
"botorch_acqf_class": qNoisyExpectedHypervolumeImprovement,
},
),
],
)

scheduler_options = SchedulerOptions(total_trials=30)

return BenchmarkMethod(
name="SOBOL+BOTORCH_MODULAR::SaasFullyBayesianSingleTaskGP_qNoisyExpectedHypervolumeImprovement", # noqa
generation_strategy=generation_strategy,
scheduler_options=scheduler_options,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"__type": "AggregatedBenchmarkResult",
"name": "BraninCurrin_30d|SOBOL_BASELINE_1650404269",
"experiments": [],
"optimization_trace": {
"__type": "DataFrame",
"value": "{\"median\":{\"0\":0.0,\"1\":0.0,\"2\":0.0,\"3\":0.0,\"4\":0.0,\"5\":0.0,\"6\":0.0,\"7\":0.0,\"8\":0.0,\"9\":0.0,\"10\":0.0,\"11\":0.0,\"12\":0.0,\"13\":0.0,\"14\":0.0,\"15\":0.0,\"16\":0.015555134,\"17\":0.015555134,\"18\":0.5546263456,\"19\":0.624468118,\"20\":0.8303934038,\"21\":1.3893842101,\"22\":3.4350678921,\"23\":5.5421590805,\"24\":6.6424357891,\"25\":6.6948719025,\"26\":7.1665885448,\"27\":7.5748984814,\"28\":7.793194294,\"29\":8.0489635468,\"30\":8.0489635468,\"31\":8.0489635468,\"32\":8.8560214043,\"33\":10.1740679741,\"34\":11.0346088409,\"35\":11.0346088409,\"36\":11.3657393456,\"37\":11.9280538559,\"38\":12.250169754,\"39\":12.250169754,\"40\":12.250169754,\"41\":12.250169754,\"42\":13.2190890312,\"43\":13.2190890312,\"44\":14.1028037071,\"45\":14.1028037071,\"46\":15.1560206413,\"47\":15.1560206413,\"48\":15.1560206413,\"49\":15.2686314583,\"50\":15.2686314583,\"51\":15.2686314583,\"52\":15.4389996529,\"53\":15.577214241,\"54\":15.9939594269,\"55\":16.6776275635,\"56\":16.6776275635,\"57\":16.6776275635,\"58\":17.6165761948,\"59\":18.4236717224,\"60\":19.0181713104,\"61\":19.0181713104,\"62\":19.0181713104,\"63\":19.2555265427,\"64\":19.2555265427,\"65\":19.4055643082,\"66\":19.4055643082,\"67\":19.4055643082,\"68\":19.4055643082,\"69\":19.4256258011,\"70\":19.4256258011,\"71\":19.4256258011,\"72\":19.7372360229,\"73\":19.7372360229,\"74\":19.8501338959,\"75\":19.8501338959,\"76\":20.2647123337,\"77\":20.2647123337,\"78\":21.1100416183,\"79\":21.7791109085,\"80\":21.9644813538,\"81\":22.6647596359,\"82\":22.6647596359,\"83\":23.0470628738,\"84\":23.0470628738,\"85\":23.0470628738,\"86\":23.4238758087,\"87\":23.4238758087,\"88\":23.4238758087,\"89\":24.1916942596,\"90\":24.1916942596,\"91\":24.1916942596,\"92\":24.1916942596,\"93\":24.1916942596,\"94\":24.844619751,\"95\":24.9136810303,\"96\":24.9136810303,\"97\":24.9136810303,\"98\":25.2412595749,\"99\":25.7766513824},\"mean\":{\"0\":0.0,\"1\":0.4809557682,\"2\":0.6751243883,\"3\":1.848341952,\"4\":2.2092813879,\"5\":2.2774870563,\"6\":2.6833242399,\"7\":3.0531603545,\"8\":3.0534714572,\"9\":3.4160163848,\"10\":4.1076178042,\"11\":5.1777149979,\"12\":5.1777149979,\"13\":5.8020842653,\"14\":5.9287117338,\"15\":6.7402318215,\"16\":7.462839248,\"17\":7.7123725579,\"18\":7.9590110133,\"19\":8.4970847772,\"20\":8.5193092303,\"21\":8.9388621041,\"22\":9.3587113205,\"23\":9.8707098334,\"24\":9.9953460901,\"25\":10.079441387,\"26\":10.3528867351,\"27\":10.637386037,\"28\":10.9466046154,\"29\":11.5567665016,\"30\":11.5866869461,\"31\":11.5866869461,\"32\":11.9989481562,\"33\":12.4527885144,\"34\":12.996428304,\"35\":13.0320507291,\"36\":13.3196417597,\"37\":13.4734232596,\"38\":13.6829085711,\"39\":13.8216553953,\"40\":13.8548669508,\"41\":14.0586935785,\"42\":14.360181196,\"43\":14.4745150974,\"44\":14.927148515,\"45\":14.927148515,\"46\":15.1169179165,\"47\":15.2393266118,\"48\":15.2922706807,\"49\":15.4999545372,\"50\":15.4999545372,\"51\":15.764865483,\"52\":15.9634421456,\"53\":16.2149620926,\"54\":16.3852402532,\"55\":16.5713604343,\"56\":16.9305286443,\"57\":17.1855359995,\"58\":17.6556331241,\"59\":18.1825532451,\"60\":18.3887750727,\"61\":18.6078670698,\"62\":18.6171949059,\"63\":19.0843626243,\"64\":19.2272744685,\"65\":19.4946573764,\"66\":19.5632806426,\"67\":19.6237084228,\"68\":19.6727945691,\"69\":19.9792077833,\"70\":20.0756787306,\"71\":20.1386352354,\"72\":20.4471281248,\"73\":20.6039339358,\"74\":20.9425478131,\"75\":20.9824826103,\"76\":21.3007528454,\"77\":21.3007528454,\"78\":21.4705757672,\"79\":21.5737115818,\"80\":21.7071152431,\"81\":22.0699624139,\"82\":22.0699624139,\"83\":22.3571001893,\"84\":22.4125080377,\"85\":22.4277465135,\"86\":22.8058339387,\"87\":22.9587328416,\"88\":22.9630507737,\"89\":23.0217729264,\"90\":23.2357786447,\"91\":23.2357786447,\"92\":23.37590514,\"93\":23.4468799287,\"94\":23.5532529908,\"95\":23.787753318,\"96\":23.8770457965,\"97\":23.9812222224,\"98\":24.1456217128,\"99\":24.6563612539},\"sem\":{\"0\":0.0,\"1\":0.2350176188,\"2\":0.2705456132,\"3\":0.5867841947,\"4\":0.6453639268,\"5\":0.6465400777,\"6\":0.6985499128,\"7\":0.7511959737,\"8\":0.7511832659,\"9\":0.8156681778,\"10\":0.8707433156,\"11\":0.9507313067,\"12\":0.9507313067,\"13\":1.0037635521,\"14\":1.0057082373,\"15\":1.0523201781,\"16\":1.0913270642,\"17\":1.1078124641,\"18\":1.1101364181,\"19\":1.1430263961,\"20\":1.1414734761,\"21\":1.1601569688,\"22\":1.1585564615,\"23\":1.1611935061,\"24\":1.1544948738,\"25\":1.155169348,\"26\":1.147509574,\"27\":1.15639663,\"28\":1.1714570796,\"29\":1.1790406478,\"30\":1.1806694388,\"31\":1.1806694388,\"32\":1.1805673234,\"33\":1.1823564188,\"34\":1.1607885545,\"35\":1.1637150311,\"36\":1.157457068,\"37\":1.149771944,\"38\":1.1605131778,\"39\":1.1610821265,\"40\":1.1616722201,\"41\":1.1794490674,\"42\":1.1854057295,\"43\":1.1783947681,\"44\":1.1818727492,\"45\":1.1818727492,\"46\":1.1841575854,\"47\":1.193782766,\"48\":1.1958262715,\"49\":1.1882193371,\"50\":1.1882193371,\"51\":1.1917391959,\"52\":1.1905367689,\"53\":1.205052649,\"54\":1.1951950536,\"55\":1.1877812792,\"56\":1.2181554958,\"57\":1.2091541794,\"58\":1.20639486,\"59\":1.2029719073,\"60\":1.1903821679,\"61\":1.1847994917,\"62\":1.1839059781,\"63\":1.1696652315,\"64\":1.1637875762,\"65\":1.1761448725,\"66\":1.1791067789,\"67\":1.1783770143,\"68\":1.1711200607,\"69\":1.1693434542,\"70\":1.1793891505,\"71\":1.1871211661,\"72\":1.1926296075,\"73\":1.1851748363,\"74\":1.1595276296,\"75\":1.1609291543,\"76\":1.1649542061,\"77\":1.1649542061,\"78\":1.1787852049,\"79\":1.1786548825,\"80\":1.1732043792,\"81\":1.1599793963,\"82\":1.1599793963,\"83\":1.1667090626,\"84\":1.1655393719,\"85\":1.1659871875,\"86\":1.1626443933,\"87\":1.1794021557,\"88\":1.1798744856,\"89\":1.1798963411,\"90\":1.1568491341,\"91\":1.1568491341,\"92\":1.1516823231,\"93\":1.1587149998,\"94\":1.1591145211,\"95\":1.1505725403,\"96\":1.1376033299,\"97\":1.1471910996,\"98\":1.1493338662,\"99\":1.1338794725}}"
},
"fit_time": [
2.853101348876953,
0.015745514644255686
],
"gen_time": [
0.3586791111854836,
0.0021974926058958215
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"__type": "AggregatedBenchmarkResult",
"name": "Hartmann_50d|SOBOL_BASELINE_1650404262",
"experiments": [],
"optimization_trace": {
"__type": "DataFrame",
"value": "{\"median\":{\"0\":-0.0948244818,\"1\":-0.224798277,\"2\":-0.3657992482,\"3\":-0.4513851404,\"4\":-0.543743223,\"5\":-0.6179356873,\"6\":-0.6783159971,\"7\":-0.799012512,\"8\":-0.8408489227,\"9\":-0.8874165714,\"10\":-0.9481673837,\"11\":-1.0616434216,\"12\":-1.092753768,\"13\":-1.1653189063,\"14\":-1.1832051873,\"15\":-1.1951954961,\"16\":-1.2114796042,\"17\":-1.2153908014,\"18\":-1.2583673596,\"19\":-1.319809854,\"20\":-1.319809854,\"21\":-1.319809854,\"22\":-1.3225582838,\"23\":-1.3328252435,\"24\":-1.3328252435,\"25\":-1.3395049572,\"26\":-1.3395049572,\"27\":-1.3777288198,\"28\":-1.4195615649,\"29\":-1.4195615649,\"30\":-1.4626145959,\"31\":-1.4733541012,\"32\":-1.5253723264,\"33\":-1.5253723264,\"34\":-1.5447837114,\"35\":-1.5652258992,\"36\":-1.5652258992,\"37\":-1.5652258992,\"38\":-1.5652258992,\"39\":-1.5862524509,\"40\":-1.6338999867,\"41\":-1.6747012138,\"42\":-1.6809063554,\"43\":-1.7065345645,\"44\":-1.7065345645,\"45\":-1.7091515064,\"46\":-1.7091515064,\"47\":-1.7091515064,\"48\":-1.7239334583,\"49\":-1.734405458,\"50\":-1.734405458,\"51\":-1.734405458,\"52\":-1.7411373258,\"53\":-1.7411373258,\"54\":-1.750967741,\"55\":-1.750967741,\"56\":-1.750967741,\"57\":-1.7575772405,\"58\":-1.7653059959,\"59\":-1.7910712957,\"60\":-1.7910712957,\"61\":-1.7910712957,\"62\":-1.7910712957,\"63\":-1.7910712957,\"64\":-1.8129526377,\"65\":-1.8129526377,\"66\":-1.8129526377,\"67\":-1.8129526377,\"68\":-1.8129526377,\"69\":-1.8510673642,\"70\":-1.8510673642,\"71\":-1.8751515746,\"72\":-1.8765948415,\"73\":-1.8893128037,\"74\":-1.8893128037,\"75\":-1.8893128037,\"76\":-1.8893128037,\"77\":-1.8893128037,\"78\":-1.9073534012,\"79\":-1.925014317,\"80\":-1.925014317,\"81\":-1.925014317,\"82\":-1.925014317,\"83\":-1.925014317,\"84\":-1.925014317,\"85\":-1.925014317,\"86\":-1.9274532795,\"87\":-1.9274532795,\"88\":-1.9274532795,\"89\":-1.9424396157,\"90\":-1.9424396157,\"91\":-1.9424396157,\"92\":-1.9424396157,\"93\":-1.9539708495,\"94\":-1.9977526665,\"95\":-1.9977526665,\"96\":-1.9977526665,\"97\":-1.9977526665,\"98\":-1.9977526665,\"99\":-1.9977526665},\"mean\":{\"0\":-0.272368158,\"1\":-0.3966269819,\"2\":-0.5171800084,\"3\":-0.5833398379,\"4\":-0.6481407857,\"5\":-0.7901653613,\"6\":-0.8645141309,\"7\":-0.9390212236,\"8\":-0.9987291093,\"9\":-1.0428132822,\"10\":-1.0862150642,\"11\":-1.1545382199,\"12\":-1.1913091607,\"13\":-1.2545111959,\"14\":-1.2688450493,\"15\":-1.2813848944,\"16\":-1.3181881675,\"17\":-1.3405613029,\"18\":-1.3770038182,\"19\":-1.411861971,\"20\":-1.43754208,\"21\":-1.4519230887,\"22\":-1.4555564973,\"23\":-1.4744651017,\"24\":-1.4826638392,\"25\":-1.5025079367,\"26\":-1.5108776727,\"27\":-1.5337431696,\"28\":-1.5409178725,\"29\":-1.5454227304,\"30\":-1.5691542339,\"31\":-1.5756913024,\"32\":-1.5959552234,\"33\":-1.6022921824,\"34\":-1.6141806293,\"35\":-1.6302795571,\"36\":-1.6302795571,\"37\":-1.6302795571,\"38\":-1.6356848007,\"39\":-1.6436587971,\"40\":-1.651992467,\"41\":-1.6632556111,\"42\":-1.6733855236,\"43\":-1.7161195838,\"44\":-1.718481046,\"45\":-1.7343091333,\"46\":-1.7392327559,\"47\":-1.7437787855,\"48\":-1.7600928843,\"49\":-1.7615198123,\"50\":-1.7618704641,\"51\":-1.7690711886,\"52\":-1.7862347168,\"53\":-1.7874842614,\"54\":-1.7960116714,\"55\":-1.7965994614,\"56\":-1.7970793885,\"57\":-1.8031001729,\"58\":-1.8079414731,\"59\":-1.8194116312,\"60\":-1.8240789086,\"61\":-1.8259234756,\"62\":-1.8311175746,\"63\":-1.8311175746,\"64\":-1.8398093349,\"65\":-1.8441612506,\"66\":-1.8441612506,\"67\":-1.8459215963,\"68\":-1.8523974073,\"69\":-1.866021421,\"70\":-1.8763734233,\"71\":-1.8845869994,\"72\":-1.9032210922,\"73\":-1.9121968853,\"74\":-1.9139404392,\"75\":-1.9162679577,\"76\":-1.9162679577,\"77\":-1.9162679577,\"78\":-1.9251958692,\"79\":-1.9349712253,\"80\":-1.9349712253,\"81\":-1.9384898674,\"82\":-1.9390314603,\"83\":-1.9390314603,\"84\":-1.9421616435,\"85\":-1.9452278733,\"86\":-1.9566326797,\"87\":-1.9566326797,\"88\":-1.9566326797,\"89\":-1.9667324841,\"90\":-1.9667324841,\"91\":-1.9667324841,\"92\":-1.9695701563,\"93\":-1.9842643917,\"94\":-1.9926702237,\"95\":-1.9982630932,\"96\":-2.0005003631,\"97\":-2.0046692979,\"98\":-2.0046692979,\"99\":-2.0046692979},\"sem\":{\"0\":0.0392350418,\"1\":0.0463405916,\"2\":0.0505655567,\"3\":0.0496661348,\"4\":0.0487556291,\"5\":0.0610454401,\"6\":0.0643405033,\"7\":0.06600557,\"8\":0.0660091919,\"9\":0.066303387,\"10\":0.0655631779,\"11\":0.065385547,\"12\":0.0649353751,\"13\":0.0630019364,\"14\":0.0626597467,\"15\":0.0616933255,\"16\":0.0604156268,\"17\":0.0586917203,\"18\":0.0609665061,\"19\":0.0607006169,\"20\":0.0605868144,\"21\":0.0609234371,\"22\":0.060702126,\"23\":0.0598068433,\"24\":0.0602129716,\"25\":0.0598368761,\"26\":0.0594063547,\"27\":0.0595261773,\"28\":0.0593712351,\"29\":0.0590011218,\"30\":0.0571160927,\"31\":0.0566771426,\"32\":0.0564494876,\"33\":0.055933424,\"34\":0.0554002141,\"35\":0.0549401001,\"36\":0.0549401001,\"37\":0.0549401001,\"38\":0.0546022039,\"39\":0.0541443816,\"40\":0.0535893635,\"41\":0.0527116964,\"42\":0.0519859412,\"43\":0.0513473586,\"44\":0.051046646,\"45\":0.0502619663,\"46\":0.0503249515,\"47\":0.0499270219,\"48\":0.050029136,\"49\":0.0498745682,\"50\":0.0498354562,\"51\":0.0488986846,\"52\":0.0496542154,\"53\":0.0495875417,\"54\":0.0494206223,\"55\":0.0494740817,\"56\":0.049426451,\"57\":0.0491406166,\"58\":0.0489446837,\"59\":0.0487674169,\"60\":0.0483755455,\"61\":0.048116258,\"62\":0.0485066305,\"63\":0.0485066305,\"64\":0.0486129575,\"65\":0.0478643938,\"66\":0.0478643938,\"67\":0.0476285385,\"68\":0.0473106301,\"69\":0.0472017085,\"70\":0.0478374951,\"71\":0.047898462,\"72\":0.047897976,\"73\":0.0475829253,\"74\":0.0474853914,\"75\":0.0474226868,\"76\":0.0474226868,\"77\":0.0474226868,\"78\":0.0469569236,\"79\":0.0461772951,\"80\":0.0461772951,\"81\":0.0458830587,\"82\":0.0458554032,\"83\":0.0458554032,\"84\":0.0461349907,\"85\":0.0461060523,\"86\":0.0452099352,\"87\":0.0452099352,\"88\":0.0452099352,\"89\":0.045205056,\"90\":0.045205056,\"91\":0.045205056,\"92\":0.0449499565,\"93\":0.0454858036,\"94\":0.0458645857,\"95\":0.0453835599,\"96\":0.0456151677,\"97\":0.0450484798,\"98\":0.0450484798,\"99\":0.0450484798}}"
},
"fit_time": [
3.2057241106033327,
0.01557995838753141
],
"gen_time": [
0.43513125223107635,
0.0018021831455838507
]
}
38 changes: 38 additions & 0 deletions ax/benchmark/problems/hd_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import asdict

from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.core.parameter import RangeParameter, ParameterType
from ax.core.search_space import SearchSpace


def embed_higher_dimension(
problem: BenchmarkProblem, total_dimensionality: int
) -> BenchmarkProblem:
num_dummy_dimensions = total_dimensionality - len(problem.search_space.parameters)

search_space = SearchSpace(
parameters=[
*problem.search_space.parameters.values(),
*[
RangeParameter(
name=f"embedding_dummy_{i}",
parameter_type=ParameterType.FLOAT,
lower=0,
upper=1,
)
for i in range(num_dummy_dimensions)
],
],
parameter_constraints=problem.search_space.parameter_constraints,
)

problem_kwargs = asdict(problem)
problem_kwargs["name"] = f"{problem_kwargs['name']}_{total_dimensionality}d"
problem_kwargs["search_space"] = search_space

return problem.__class__(**problem_kwargs)
23 changes: 22 additions & 1 deletion ax/benchmark/problems/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
SingleObjectiveBenchmarkProblem,
)
from ax.benchmark.benchmark_result import AggregatedBenchmarkResult
from ax.benchmark.problems.hd_embedding import embed_higher_dimension
from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionBenchmarkProblem
from ax.storage.json_store.decoder import object_from_json
from botorch.test_functions.multi_objective import BraninCurrin
from botorch.test_functions.synthetic import Branin, Ackley
from botorch.test_functions.synthetic import Hartmann, Branin, Ackley


@dataclass
Expand All @@ -43,6 +44,26 @@ class BenchmarkProblemRegistryEntry:
factory_kwargs={"test_problem": BraninCurrin()},
baseline_results_path="baseline_results/synthetic/branin_currin.json",
),
"branin_currin30": BenchmarkProblemRegistryEntry(
factory_fn=lambda n: embed_higher_dimension(
problem=MultiObjectiveBenchmarkProblem.from_botorch_multi_objective(
test_problem=BraninCurrin()
),
total_dimensionality=n,
),
factory_kwargs={"n": 30},
baseline_results_path="baseline_results/synthetic/hd/branin_currin_30d.json",
),
"hartmann50": BenchmarkProblemRegistryEntry(
factory_fn=lambda n: embed_higher_dimension(
problem=SingleObjectiveBenchmarkProblem.from_botorch_synthetic(
test_problem=Hartmann(dim=6)
),
total_dimensionality=n,
),
factory_kwargs={"n": 50},
baseline_results_path="baseline_results/synthetic/hd/hartmann_50d.json",
),
"hpo_pytorch_cnn_MNIST": BenchmarkProblemRegistryEntry(
factory_fn=PyTorchCNNTorchvisionBenchmarkProblem.from_dataset_name,
factory_kwargs={"name": "MNIST"},
Expand Down
9 changes: 8 additions & 1 deletion ax/runners/botorch_test_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,14 @@ def run(self, trial: BaseTrial) -> Dict[str, Any]:
return {
"Ys": {
arm.name: self.test_problem.forward(
torch.tensor([value for _key, value in arm.parameters.items()])
torch.tensor(
[
value
for _key, value in [*arm.parameters.items()][
: self.test_problem.dim
]
]
)
).tolist()
for arm in trial.arms
},
Expand Down
8 changes: 8 additions & 0 deletions sphinx/source/benchmark.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ Benchmark Problems Registry
:undoc-members:
:show-inheritance:

Benchmark Problems High Dimensional Embedding
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: ax.benchmark.problems.hd_embedding
:members:
:undoc-members:
:show-inheritance:

Benchmark Problems PyTorchCNN
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down

0 comments on commit 751fa64

Please sign in to comment.