diff --git a/ax/modelbridge/transforms/relativize.py b/ax/modelbridge/transforms/relativize.py index e9389afbf37..ec763a0d5a6 100644 --- a/ax/modelbridge/transforms/relativize.py +++ b/ax/modelbridge/transforms/relativize.py @@ -168,15 +168,10 @@ def _rel_op_on_observations( self.modelbridge.status_quo_data_by_trial, self.MISSING_STATUS_QUO_ERROR ) - missing_index = any(obs.features.trial_index is None for obs in observations) - default_trial_idx: Optional[int] = None - if missing_index: - if len(sq_data_by_trial) == 1: - default_trial_idx = next(iter(sq_data_by_trial)) - else: - raise ValueError( - "Observations contain missing trial index that can't be inferred." - ) + # use latest index of latest observed trial by default + # to handle pending trials, which may not have a trial_index + # if TrialAsTask was not used to generate the trial. + default_trial_idx: int = max(sq_data_by_trial.keys()) def _get_relative_data_from_obs( obs: Observation, diff --git a/ax/modelbridge/transforms/tests/test_relativize_transform.py b/ax/modelbridge/transforms/tests/test_relativize_transform.py index 1d65cf9cc20..21f3d888a5c 100644 --- a/ax/modelbridge/transforms/tests/test_relativize_transform.py +++ b/ax/modelbridge/transforms/tests/test_relativize_transform.py @@ -7,7 +7,7 @@ from copy import deepcopy from typing import List, Tuple -from unittest.mock import Mock, patch, PropertyMock +from unittest.mock import Mock import numpy as np from ax.core import BatchTrial @@ -159,21 +159,6 @@ def test_relativize_transform_requires_a_modelbridge_to_have_status_quo_data( observations[0].features.trial_index = 999 self.assertRaises(ValueError, tf.transform_observations, observations) - # When observation has missing trial_index and - # modelbridge.status_quo_data_by_trial has more than one trial, - # raise exception - observations[0].features.trial_index = None - with patch.object( - type(modelbridge), "status_quo_data_by_trial", new_callable=PropertyMock - ) as mock_sq_dict: - # Making modelbridge.status_quo_data_by_trial contains 2 trials - mock_sq_dict.return_value = {0: Mock(), 1: Mock()} - with self.assertRaisesRegex( - ValueError, - "Observations contain missing trial index that can't be inferred.", - ): - tf.transform_observations(observations) - def test_relativize_transform_observations(self) -> None: def _check_transform_observations( tf: Transform, @@ -257,8 +242,18 @@ def _check_transform_observations( observations=observations, expected_mean_and_covar=expected_mean_and_covar, ) - # transform should still work when trial_index is None and - # there is only one sq in modelbridge + # transform should still work when trial_index is None + modelbridge = Mock( + status_quo=Mock( + data=obs_data[0], features=obs_features[0], arm_name=arm_names[0] + ), + status_quo_data_by_trial={0: obs_data[1], 1: obs_data[0]}, + ) + tf = relativize_cls( + search_space=None, + observations=observations, + modelbridge=modelbridge, + ) for obs in observations: obs.features.trial_index = None _check_transform_observations(