Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add auxiliary_experiments to Experiment #2634

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions ax/core/auxiliary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from enum import Enum, unique
from typing import Optional, TYPE_CHECKING

from ax.core.data import Data
from ax.utils.common.base import SortableBase


if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import core # noqa F401


class AuxiliaryExperiment(SortableBase):
"""Class for defining an auxiliary experiment."""

def __init__(
self,
experiment: core.experiment.Experiment,
data: Optional[Data] = None,
) -> None:
"""
Lightweight container of an experiment, and its data,
that will be used as auxiliary information for another experiment.
"""
self.experiment = experiment
self.data: Data = data or experiment.lookup_data()

def _unique_id(self) -> str:
# While there can be multiple `AuxiliarySource`-s made from the same
# experiment (and thus sharing the experiment name), the uniqueness
# here is only needed w.r.t. parent object ("main experiment", for which
# this will be an auxiliary source for).
return self.experiment.name


@unique
class AuxiliaryExperimentPurpose(Enum):
pass
15 changes: 13 additions & 2 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
from collections.abc import Hashable, Iterable, Mapping
from datetime import datetime
from functools import partial, reduce

from typing import Any, Optional

import ax.core.observation as observation
import pandas as pd
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
from ax.core.base_trial import BaseTrial, DEFAULT_STATUSES_TO_WARM_START, TrialStatus
from ax.core.batch_trial import BatchTrial, LifecycleStage
from ax.core.data import Data
Expand Down Expand Up @@ -79,6 +81,9 @@ def __init__(
experiment_type: Optional[str] = None,
properties: Optional[dict[str, Any]] = None,
default_data_type: Optional[DataType] = None,
auxiliary_experiments_by_purpose: Optional[
dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]]
] = None,
) -> None:
"""Inits Experiment.

Expand All @@ -94,6 +99,8 @@ def __init__(
experiment_type: The class of experiments this one belongs to.
properties: Dictionary of this experiment's properties.
default_data_type: Enum representing the data type this experiment uses.
auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments
for different purposes (e.g., transfer learning).
"""
# appease pyre
self._search_space: SearchSpace
Expand Down Expand Up @@ -127,6 +134,10 @@ def __init__(
self._arms_by_signature: dict[str, Arm] = {}
self._arms_by_name: dict[str, Arm] = {}

self.auxiliary_experiments_by_purpose: dict[
AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]
] = (auxiliary_experiments_by_purpose or {})

self.add_tracking_metrics(tracking_metrics or [])

# call setters defined below
Expand Down Expand Up @@ -1020,14 +1031,14 @@ def trials_by_status(self) -> dict[TrialStatus, list[BaseTrial]]:

@property
def trials_expecting_data(self) -> list[BaseTrial]:
"""List[BaseTrial]: the list of all trials for which data has arrived
"""list[BaseTrial]: the list of all trials for which data has arrived
or is expected to arrive.
"""
return [trial for trial in self.trials.values() if trial.status.expecting_data]

@property
def completed_trials(self) -> list[BaseTrial]:
"""List[BaseTrial]: the list of all trials for which data has arrived
"""list[BaseTrial]: the list of all trials for which data has arrived
or is expected to arrive.
"""
return self.trials_by_status[TrialStatus.COMPLETED]
Expand Down
26 changes: 26 additions & 0 deletions ax/core/tests/test_auxiliary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ax.core.auxiliary import AuxiliaryExperiment
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_experiment, get_experiment_with_data


class AuxiliaryExperimentTest(TestCase):
def test_AuxiliaryExperiment(self) -> None:
for get_exp_func in [get_experiment, get_experiment_with_data]:
exp = get_exp_func()
data = exp.lookup_data()

# Test init
aux_exp = AuxiliaryExperiment(experiment=exp)
self.assertEqual(aux_exp.experiment, exp)
self.assertEqual(aux_exp.data, data)

another_aux_exp = AuxiliaryExperiment(
experiment=exp, data=exp.lookup_data()
)
self.assertEqual(another_aux_exp.experiment, exp)
self.assertEqual(another_aux_exp.data, data)
59 changes: 57 additions & 2 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@

import logging
from collections import OrderedDict
from enum import unique
from unittest.mock import MagicMock, patch

import pandas as pd
from ax.core import BatchTrial, Trial
from ax.core import BatchTrial, Experiment, Trial
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
from ax.core.base_trial import TrialStatus
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.map_data import MapData
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric
Expand Down Expand Up @@ -53,6 +54,7 @@
get_branin_search_space,
get_data,
get_experiment,
get_experiment_with_data,
get_experiment_with_map_data_type,
get_optimization_config,
get_scalarized_outcome_constraint,
Expand Down Expand Up @@ -1471,3 +1473,56 @@ def test_it_does_not_take_both_single_and_multiple_gr_ars(self) -> None:
generator_run=gr1,
generator_runs=[gr2],
)

def test_experiment_with_aux_experiments(self) -> None:
@unique
class TestAuxiliaryExperimentPurpose(AuxiliaryExperimentPurpose):
MyAuxExpPurpose = "my_auxiliary_experiment_purpose"
MyOtherAuxExpPurpose = "my_other_auxiliary_experiment_purpose"

for get_exp_func in [get_experiment, get_experiment_with_data]:
exp = get_exp_func()
data = exp.lookup_data()

aux_exp = AuxiliaryExperiment(experiment=exp)
another_aux_exp = AuxiliaryExperiment(experiment=exp, data=data)

# init experiment with auxiliary experiments
exp_w_aux_exp = Experiment(
name="test",
search_space=get_search_space(),
auxiliary_experiments_by_purpose={
TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp],
},
)

# in-place modification of auxiliary experiments
exp_w_aux_exp.auxiliary_experiments_by_purpose[
TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose
] = [aux_exp]
self.assertEqual(
exp_w_aux_exp.auxiliary_experiments_by_purpose,
{
TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp],
TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose: [aux_exp],
},
)

# test setter
exp_w_aux_exp.auxiliary_experiments_by_purpose = {
TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp],
TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose: [
aux_exp,
another_aux_exp,
],
}
self.assertEqual(
exp_w_aux_exp.auxiliary_experiments_by_purpose,
{
TestAuxiliaryExperimentPurpose.MyAuxExpPurpose: [aux_exp],
TestAuxiliaryExperimentPurpose.MyOtherAuxExpPurpose: [
aux_exp,
another_aux_exp,
],
},
)
12 changes: 10 additions & 2 deletions ax/service/utils/instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Optional, Union

from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
from ax.core.experiment import DataType, Experiment
from ax.core.metric import Metric
from ax.core.objective import MultiObjective, Objective
Expand Down Expand Up @@ -784,6 +785,9 @@ def make_experiment(
objective_thresholds: Optional[list[str]] = None,
support_intermediate_data: bool = False,
immutable_search_space_and_opt_config: bool = True,
auxiliary_experiments_by_purpose: Optional[
dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]]
] = None,
is_test: bool = False,
) -> Experiment:
"""Instantiation wrapper that allows for Ax `Experiment` creation
Expand Down Expand Up @@ -823,6 +827,8 @@ def make_experiment(
a product in which it is used), if any.
tracking_metric_names: Names of additional tracking metrics not used for
optimization.
metric_definitions: A mapping of metric names to extra kwargs to pass
to that metric
objectives: Mapping from an objective name to "minimize" or "maximize"
representing the direction for that objective.
objective_thresholds: A list of objective threshold constraints for multi-
Expand All @@ -835,10 +841,11 @@ def make_experiment(
Defaults to True. If set to True, we won't store or load copies of the
search space and optimization config on each generator run, which will
improve storage performance.
auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments for
different use cases (e.g., transfer learning).
is_test: Whether this experiment will be a test experiment (useful for
marking test experiments in storage etc). Defaults to False.
metric_definitions: A mapping of metric names to extra kwargs to pass
to that metric

"""
status_quo_arm = None if status_quo is None else Arm(parameters=status_quo)

Expand Down Expand Up @@ -889,6 +896,7 @@ def make_experiment(
tracking_metrics=tracking_metrics,
default_data_type=default_data_type,
properties=properties,
auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose,
is_test=is_test,
)

Expand Down
14 changes: 12 additions & 2 deletions ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
from typing import Any

from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionBenchmarkProblem
from ax.core import ObservationFeatures
from ax.core import Experiment, ObservationFeatures
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.map_data import MapData, MapKeyInfo
from ax.core.metric import Metric
Expand Down Expand Up @@ -710,6 +710,16 @@ def risk_measure_to_dict(
}


def auxiliary_experiment_to_dict(
auxiliary_experiment: AuxiliaryExperiment,
) -> dict[str, Any]:
return {
"__type": auxiliary_experiment.__class__.__name__,
"experiment": auxiliary_experiment.experiment,
"data": auxiliary_experiment.data,
}


def pathlib_to_dict(path: Path) -> dict[str, Any]:
return {"__type": path.__class__.__name__, "pathsegments": [str(path)]}

Expand Down
8 changes: 6 additions & 2 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
)
from ax.benchmark.runners.botorch_test import BotorchTestProblemRunner
from ax.benchmark.runners.surrogate import SurrogateRunner
from ax.core import ObservationFeatures
from ax.core import Experiment, ObservationFeatures
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment
from ax.core.base_trial import TrialStatus
from ax.core.batch_trial import (
AbandonedArm,
Expand All @@ -35,7 +36,7 @@
LifecycleStage,
)
from ax.core.data import Data
from ax.core.experiment import DataType, Experiment
from ax.core.experiment import DataType
from ax.core.generator_run import GeneratorRun
from ax.core.map_data import MapData, MapKeyInfo
from ax.core.map_metric import MapMetric
Expand Down Expand Up @@ -116,6 +117,7 @@
)
from ax.storage.json_store.encoders import (
arm_to_dict,
auxiliary_experiment_to_dict,
batch_to_dict,
best_model_selector_to_dict,
botorch_component_to_dict,
Expand Down Expand Up @@ -181,6 +183,7 @@
# avoid runtime subscripting errors.
CORE_ENCODER_REGISTRY: dict[type, Callable[[Any], dict[str, Any]]] = {
Arm: arm_to_dict,
AuxiliaryExperiment: auxiliary_experiment_to_dict,
AndEarlyStoppingStrategy: logical_early_stopping_strategy_to_dict,
AugmentedBraninMetric: metric_to_dict,
AugmentedHartmann6Metric: metric_to_dict,
Expand Down Expand Up @@ -293,6 +296,7 @@
"AugmentedBraninMetric": AugmentedBraninMetric,
"AugmentedHartmann6Metric": AugmentedHartmann6Metric,
"AutoTransitionAfterGen": AutoTransitionAfterGen,
"AuxiliaryExperiment": AuxiliaryExperiment,
"Arm": Arm,
"AggregatedBenchmarkResult": AggregatedBenchmarkResult,
"BatchTrial": BatchTrial,
Expand Down
2 changes: 2 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
get_arm,
get_augmented_branin_metric,
get_augmented_hartmann_metric,
get_auxiliary_experiment,
get_batch_trial,
get_botorch_model,
get_botorch_model_with_default_acquisition_class,
Expand Down Expand Up @@ -141,6 +142,7 @@
("Arm", get_arm),
("AugmentedBraninMetric", get_augmented_branin_metric),
("AugmentedHartmannMetric", get_augmented_hartmann_metric),
("AuxiliaryExperiment", get_auxiliary_experiment),
("BatchTrial", get_batch_trial),
("BenchmarkMethod", get_sobol_gpei_benchmark_method),
("BenchmarkProblem", get_single_objective_benchmark_problem),
Expand Down
5 changes: 5 additions & 0 deletions ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pandas as pd
import torch
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import AbandonedArm, BatchTrial
from ax.core.data import Data
Expand Down Expand Up @@ -879,6 +880,10 @@ def get_high_dimensional_branin_experiment(with_batch: bool = False) -> Experime
return exp


def get_auxiliary_experiment() -> AuxiliaryExperiment:
return AuxiliaryExperiment(experiment=get_experiment_with_data())


##############################
# Search Spaces
##############################
Expand Down
Loading