Skip to content

Commit

Permalink
Use value of enum instead of casting to string (facebook#2423)
Browse files Browse the repository at this point in the history
Summary:

`Keys.PAIRWISE_PREFERENCE_QUERY` or `str(Keys.PAIRWISE_PREFERENCE_QUERY)` is commonly used when a string is expected, whereas I'm suspecting the intention is `Keys.PAIRWISE_PREFERENCE_QUERY.value`. This is causing issues in D56634321 which (for now) fails if any metrics in Data.df are not present on experiment when calling observations_from_dataframe.

Reviewed By: Balandat

Differential Revision: D56850035
  • Loading branch information
bernardbeckerman authored and facebook-github-bot committed May 6, 2024
1 parent a59a36d commit 29262ac
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions ax/modelbridge/tests/test_pairwise_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def evaluate(
arm2_sum = float(sum(arm2_outcome_values))
is_arm1_preferred = int(arm1_sum - arm2_sum > 0)
return {
arm1: {Keys.PAIRWISE_PREFERENCE_QUERY: is_arm1_preferred},
arm2: {Keys.PAIRWISE_PREFERENCE_QUERY: 1 - is_arm1_preferred},
arm1: {Keys.PAIRWISE_PREFERENCE_QUERY.value: is_arm1_preferred},
arm2: {Keys.PAIRWISE_PREFERENCE_QUERY.value: 1 - is_arm1_preferred},
}

experiment = InstantiationBase.make_experiment(
Expand All @@ -70,7 +70,7 @@ def evaluate(
"bounds": [0.0, 0.7],
},
],
objectives={Keys.PAIRWISE_PREFERENCE_QUERY: "minimize"},
objectives={Keys.PAIRWISE_PREFERENCE_QUERY.value: "minimize"},
is_test=True,
)

Expand Down Expand Up @@ -145,12 +145,12 @@ def test_PairwiseModelBridge(self) -> None:

observation_data = [
ObservationData(
metric_names=[Keys.PAIRWISE_PREFERENCE_QUERY],
metric_names=[Keys.PAIRWISE_PREFERENCE_QUERY.value],
means=np.array([0]),
covariance=np.array([[np.nan]]),
),
ObservationData(
metric_names=[Keys.PAIRWISE_PREFERENCE_QUERY],
metric_names=[Keys.PAIRWISE_PREFERENCE_QUERY.value],
means=np.array([1]),
covariance=np.array([[np.nan]]),
),
Expand All @@ -168,7 +168,7 @@ def test_PairwiseModelBridge(self) -> None:
),
]
parameters = ["X1", "X2"]
outcomes = [checked_cast(str, Keys.PAIRWISE_PREFERENCE_QUERY)]
outcomes = [checked_cast(str, Keys.PAIRWISE_PREFERENCE_QUERY.value)]

datasets, _, candidate_metadata = pmb._convert_observations(
observation_data=observation_data,
Expand Down

0 comments on commit 29262ac

Please sign in to comment.