Skip to content

Commit

Permalink
handle ObservationFeatures without trial_index in Relativize (#2441)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2441

This  adds support for untransforming observation features in Relativize that do not have a trial index. This can occur when the Modelbridge does not use TrialAsTask and there are multiple SQ observations (e.g. for different trials) and we try to untransform the predictions for arms in a GR.

Reviewed By: bernardbeckerman

Differential Revision: D57126153

fbshipit-source-id: 7e463bdaaad13850d19b7b6331fe1bd764d8e14e
  • Loading branch information
sdaulton authored and facebook-github-bot committed May 9, 2024
1 parent 8900752 commit 9a624a2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 27 deletions.
13 changes: 4 additions & 9 deletions ax/modelbridge/transforms/relativize.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,10 @@ def _rel_op_on_observations(
self.modelbridge.status_quo_data_by_trial, self.MISSING_STATUS_QUO_ERROR
)

missing_index = any(obs.features.trial_index is None for obs in observations)
default_trial_idx: Optional[int] = None
if missing_index:
if len(sq_data_by_trial) == 1:
default_trial_idx = next(iter(sq_data_by_trial))
else:
raise ValueError(
"Observations contain missing trial index that can't be inferred."
)
# use latest index of latest observed trial by default
# to handle pending trials, which may not have a trial_index
# if TrialAsTask was not used to generate the trial.
default_trial_idx: int = max(sq_data_by_trial.keys())

def _get_relative_data_from_obs(
obs: Observation,
Expand Down
31 changes: 13 additions & 18 deletions ax/modelbridge/transforms/tests/test_relativize_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from copy import deepcopy
from typing import List, Tuple
from unittest.mock import Mock, patch, PropertyMock
from unittest.mock import Mock

import numpy as np
from ax.core import BatchTrial
Expand Down Expand Up @@ -159,21 +159,6 @@ def test_relativize_transform_requires_a_modelbridge_to_have_status_quo_data(
observations[0].features.trial_index = 999
self.assertRaises(ValueError, tf.transform_observations, observations)

# When observation has missing trial_index and
# modelbridge.status_quo_data_by_trial has more than one trial,
# raise exception
observations[0].features.trial_index = None
with patch.object(
type(modelbridge), "status_quo_data_by_trial", new_callable=PropertyMock
) as mock_sq_dict:
# Making modelbridge.status_quo_data_by_trial contains 2 trials
mock_sq_dict.return_value = {0: Mock(), 1: Mock()}
with self.assertRaisesRegex(
ValueError,
"Observations contain missing trial index that can't be inferred.",
):
tf.transform_observations(observations)

def test_relativize_transform_observations(self) -> None:
def _check_transform_observations(
tf: Transform,
Expand Down Expand Up @@ -257,8 +242,18 @@ def _check_transform_observations(
observations=observations,
expected_mean_and_covar=expected_mean_and_covar,
)
# transform should still work when trial_index is None and
# there is only one sq in modelbridge
# transform should still work when trial_index is None
modelbridge = Mock(
status_quo=Mock(
data=obs_data[0], features=obs_features[0], arm_name=arm_names[0]
),
status_quo_data_by_trial={0: obs_data[1], 1: obs_data[0]},
)
tf = relativize_cls(
search_space=None,
observations=observations,
modelbridge=modelbridge,
)
for obs in observations:
obs.features.trial_index = None
_check_transform_observations(
Expand Down

0 comments on commit 9a624a2

Please sign in to comment.