
Commit

Clarify that inter-point constraints are not supported by `get_polytope_samples`, raise informative error.

This addresses pytorch#2468
Balandat committed Aug 14, 2024
1 parent f4c2915 commit 359e14a
Showing 2 changed files with 27 additions and 10 deletions.
30 changes: 20 additions & 10 deletions botorch/utils/sampling.py
@@ -831,15 +831,17 @@ def normalize_sparse_linear_constraints(
     r"""Normalize sparse linear constraints to the unit cube.

     Args:
-        bounds (Tensor): A `2 x d`-dim tensor containing the box bounds.
-        constraints (List[Tuple[Tensor, Tensor, float]]): A list of
-            tuples (indices, coefficients, rhs), with each tuple encoding
-            an inequality constraint of the form
+        bounds: A `2 x d`-dim tensor containing the box bounds.
+        constraints: A list of tuples (`indices`, `coefficients`, `rhs`), with
+            `indices` and `coefficients` one-dimensional tensors and `rhs` a
+            scalar, where each tuple encodes an inequality constraint of the form
             `\sum_i (X[indices[i]] * coefficients[i]) >= rhs` or
             `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
     """
     new_constraints = []
     for index, coefficient, rhs in constraints:
+        if index.ndim != 1:
+            raise ValueError("`indices` must be a one-dimensional tensor.")
         lower, upper = bounds[:, index]
         s = upper - lower
         new_constraints.append(
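For reference, here is a minimal sketch of the sparse constraint format this hunk documents. It assumes `normalize_sparse_linear_constraints` is importable from `botorch.utils.sampling` with the `(bounds, constraints)` signature shown above; the box bounds and constraint values are illustrative only, not part of the commit.

import torch

from botorch.utils.sampling import normalize_sparse_linear_constraints

# 2 x d box bounds for a 3-dimensional problem (illustrative values).
bounds = torch.tensor([[0.0, 0.0, 0.0], [2.0, 4.0, 1.0]])

# Encodes x[0] + 2 * x[1] >= 1 as (indices, coefficients, rhs), with
# `indices` and `coefficients` one-dimensional tensors and `rhs` a scalar.
constraints = [(torch.tensor([0, 1]), torch.tensor([1.0, 2.0]), 1.0)]

# Rescales the constraint to the unit cube implied by `bounds`.
normalized = normalize_sparse_linear_constraints(bounds, constraints)

# A 2-d `indices` tensor now triggers the informative error added in this commit:
# normalize_sparse_linear_constraints(
#     bounds, [(torch.tensor([[0, 1]]), torch.tensor([1.0, 2.0]), 1.0)]
# )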
@@ -894,14 +896,21 @@ def get_polytope_samples(
     from the `Ax >= b` format expected here to the `Ax <= b` format expected by
     `PolytopeSampler` by multiplying both `A` and `b` by -1.)

+    NOTE: This method does not support the kind of "inter-point constraints" that
+    are supported by `optimize_acqf()`. To achieve this behavior, you need to define
+    the problem on the joint space over `q` points and impose the constraints there,
+    see: https://github.com/pytorch/botorch/issues/2468#issuecomment-2287706461
+
     Args:
         n: The number of samples.
         bounds: A `2 x d`-dim tensor containing the box bounds.
-        inequality_constraints: A list of tuples (indices, coefficients, rhs),
-            with each tuple encoding an inequality constraint of the form
+        inequality_constraints: A list of tuples (`indices`, `coefficients`, `rhs`),
+            with `indices` and `coefficients` one-dimensional tensors and `rhs` a
+            scalar, where each tuple encodes an inequality constraint of the form
             `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
-        equality_constraints: A list of tuples (indices, coefficients, rhs),
-            with each tuple encoding an inequality constraint of the form
+        equality_constraints: A list of tuples (`indices`, `coefficients`, `rhs`),
+            with `indices` and `coefficients` one-dimensional tensors and `rhs` a
+            scalar, where each tuple encodes an equality constraint of the form
             `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
         seed: The random seed.
         n_burnin: The number of burn-in samples for the Markov chain sampler.
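As context for the `Ax >= b` convention documented above, here is a hedged usage sketch of `get_polytope_samples` with a single inequality constraint, using the `n`, `bounds`, `inequality_constraints`, and `seed` arguments listed in this hunk; the bounds, constraint values, and sample count are illustrative, not taken from the commit.

import torch

from botorch.utils.sampling import get_polytope_samples

# Unit box in d = 2 dimensions.
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])

# x[0] + x[1] >= 0.5, in the (indices, coefficients, rhs) format described above.
inequality_constraints = [(torch.tensor([0, 1]), torch.tensor([1.0, 1.0]), 0.5)]

samples = get_polytope_samples(
    n=16,
    bounds=bounds,
    inequality_constraints=inequality_constraints,
    seed=1234,
)
# Each of the 16 samples lies in the box and satisfies the constraint.
print(samples.shape)  # torch.Size([16, 2])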
@@ -950,8 +959,9 @@ def sparse_to_dense_constraints(
     Args:
         d: The input dimension.
-        inequality constraints: A list of tuples (indices, coefficients, rhs),
-            with each tuple encoding an (in)equality constraint of the form
+        constraints: A list of tuples (`indices`, `coefficients`, `rhs`),
+            with `indices` and `coefficients` one-dimensional tensors and `rhs` a
+            scalar, where each tuple encodes an (in)equality constraint of the form
             `\sum_i (X[indices[i]] * coefficients[i]) >= rhs` or
             `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
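An illustrative sketch of the sparse-to-dense conversion documented in this hunk, assuming `sparse_to_dense_constraints(d, constraints)` returns a dense `(A, b)` pair for constraints of the form `A x >= b` (or `A x = b`); the return shape and the example values are assumptions for illustration, not taken from the diff.

import torch

from botorch.utils.sampling import sparse_to_dense_constraints

# One sparse constraint on a d = 3 input: x[0] - x[2] >= 0.25.
d = 3
constraints = [(torch.tensor([0, 2]), torch.tensor([1.0, -1.0]), 0.25)]

A, b = sparse_to_dense_constraints(d=d, constraints=constraints)
# A densifies the coefficients, with zeros for unreferenced dimensions,
# e.g. A = [[1.0, 0.0, -1.0]] and b = [[0.25]].
print(A, b)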
7 changes: 7 additions & 0 deletions test/utils/test_sampling.py
@@ -206,6 +206,13 @@ def test_normalize_sparse_linear_constraints(self):
         )
         expected_rhs = 0.5
         self.assertAlmostEqual(new_constraints[0][-1], expected_rhs)
+        with self.assertRaisesRegex(
+            ValueError, "`indices` must be a one-dimensional tensor."
+        ):
+            normalize_sparse_linear_constraints(
+                bounds,
+                [(torch.tensor([[1, 2], [3, 4]]), torch.tensor([1.0, 1.0]), 1.0)],
+            )

     def test_normalize_sparse_linear_constraints_wrong_dtype(self):
         for dtype in (torch.float, torch.double):
