Skip to content

Commit

Permalink
Don't call dataset.X in input transform constructors (facebook#1993)
Browse files Browse the repository at this point in the history
Summary:

`dataset.X` will error out if using a `MultiTaskDataset` with heterogeneous feature sets. Updated the code to extract `d` from `feature_names` instead.

For Warp input constructor, removed the `batch_shape` argument. In Ax, we don't have batched inputs for models, so this would never get used.

Differential Revision: D51362512
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Nov 16, 2023
1 parent 6a6c92d commit 5df9464
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.models.torch.botorch_defaults import _get_batch_shape
from ax.models.torch.utils import normalize_indices
from ax.utils.common.typeutils import _argparse_type_encoder
from botorch.models.transforms.input import (
Expand All @@ -20,7 +19,7 @@
Warp,
)
from botorch.utils.containers import SliceContainer
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.datasets import RankingDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher


Expand Down Expand Up @@ -75,22 +74,15 @@ def _input_transform_argparse_warp(
Returns:
A dictionary with input transform kwargs.
"""

input_transform_options = input_transform_options or {}
d = dataset.X.shape[-1]

d = len(dataset.feature_names)
indices = list(range(d))

task_features = normalize_indices(search_space_digest.task_features, d=d)

for task_feature in sorted(task_features, reverse=True):
del indices[task_feature]

batch_shape = _get_batch_shape(dataset.X, dataset.Y)

input_transform_options.setdefault("indices", indices)
input_transform_options.setdefault("batch_shape", batch_shape)

return input_transform_options


Expand Down Expand Up @@ -118,18 +110,15 @@ def _input_transform_argparse_normalize(
Returns:
A dictionary with input transform kwargs.
"""

input_transform_options = input_transform_options or {}

d = dataset.X.shape[-1]

d = input_transform_options.get("d", len(dataset.feature_names))
bounds = torch.as_tensor(
search_space_digest.bounds,
dtype=torch_dtype,
device=torch_device,
).T

if isinstance(dataset.X, SliceContainer):
if isinstance(dataset, RankingDataset) and isinstance(dataset.X, SliceContainer):
d = dataset.X.values.shape[-1]

indices = list(range(d))
Expand Down
33 changes: 26 additions & 7 deletions ax/models/torch/tests/test_input_transform_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Normalize,
Warp,
)
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset


class DummyInputTransform(InputTransform): # pyre-ignore [13]
Expand Down Expand Up @@ -92,7 +92,6 @@ def test_argparse_input_transform(self) -> None:
self.assertEqual(input_transform_kwargs, {"d": 10})

def test_argparse_normalize(self) -> None:

input_transform_kwargs = input_transform_argparse(
Normalize,
dataset=self.dataset,
Expand Down Expand Up @@ -137,8 +136,30 @@ def test_argparse_normalize(self) -> None:
)
)

def test_argparse_warp(self) -> None:
# Test with MultiTaskDataset.
dataset1 = SupervisedDataset(
X=torch.rand(5, 4),
Y=torch.randn(5, 1),
feature_names=[f"x{i}" for i in range(4)],
outcome_names=["y0"],
)

dataset2 = SupervisedDataset(
X=torch.rand(5, 2),
Y=torch.randn(5, 1),
feature_names=[f"x{i}" for i in range(2)],
outcome_names=["y1"],
)
mtds = MultiTaskDataset(datasets=[dataset1, dataset2], target_outcome_name="y0")
input_transform_kwargs = input_transform_argparse(
Normalize,
dataset=mtds,
search_space_digest=self.search_space_digest,
)
self.assertEqual(input_transform_kwargs["d"], 4)
self.assertEqual(input_transform_kwargs["indices"], [0, 1, 2])

def test_argparse_warp(self) -> None:
self.search_space_digest.task_features = [0, 3]
input_transform_kwargs = input_transform_argparse(
Warp,
Expand All @@ -148,7 +169,7 @@ def test_argparse_warp(self) -> None:

self.assertEqual(
input_transform_kwargs,
{"indices": [1, 2], "batch_shape": torch.Size([2])},
{"indices": [1, 2]},
)

input_transform_kwargs = input_transform_argparse(
Expand All @@ -158,9 +179,7 @@ def test_argparse_warp(self) -> None:
input_transform_options={"indices": [0, 1]},
)

self.assertEqual(
input_transform_kwargs, {"indices": [0, 1], "batch_shape": torch.Size([2])}
)
self.assertEqual(input_transform_kwargs, {"indices": [0, 1]})

def test_argparse_input_perturbation(self) -> None:

Expand Down

0 comments on commit 5df9464

Please sign in to comment.