How to make qEHVI or qNEHVI work with MultitaskGP? #2420
Replies: 5 comments
-
Somebody might have a more expedient response, but here is an example of
MOBO with an MTGP using fast sampling.
It would be nice to have a (potentially bite-sized) tutorial on the BoTorch
site.
https://github.com/wjmaddox/mtgp_sampler/blob/master/mtgp_experiments/constrained_mobo.py
…On Tue, Jul 9, 2024 at 4:24 AM sheikh ahnaf ***@***.***> wrote:
Hi,
Whenever I try to use qNEHVI with MultiTaskGP, it throws the following
error:
NotImplementedError Traceback (most recent call last)
Cell In[35], line 24
22 # not specify output_tasks will give an output for each task
23 model = MultiTaskGP(train_X, train_Y, task_feature=1)
---> 24 qNEHVI = qNoisyExpectedHypervolumeImprovement(
25 model=model,
26 ref_point=torch.zeros(2),
27 X_baseline=train_X,
28 )
29 test_X = torch.rand(64,8,2) # batch_shape=64 x q=8 x d=4
30 posterior = model.posterior(full_train_x,output_indices=full_train_i.type_as(full_train_x))
File /scratch/user/ahnafalvi/envs/gptor2/lib/python3.12/site-packages/botorch/acquisition/multi_objective/monte_carlo.py:420, in qNoisyExpectedHypervolumeImprovement.__init__(self, model, ref_point, X_baseline, sampler, objective, constraints, X_pending, eta, fat, prune_baseline, alpha, cache_pending, max_iep, incremental_nehvi, cache_root, marginalize_dim)
411 MultiObjectiveMCAcquisitionFunction.__init__(
412 self,
413 model=model,
(...)
417 eta=eta,
418 )
419 SubsetIndexCachingMixin.__init__(self)
--> 420 NoisyExpectedHypervolumeMixin.__init__(
421 self,
422 model=model,
423 ref_point=ref_point,
424 X_baseline=X_baseline,
425 sampler=self.sampler,
426 objective=self.objective,
427 constraints=self.constraints,
428 X_pending=X_pending,
429 prune_baseline=prune_baseline,
430 alpha=alpha,
431 cache_pending=cache_pending,
432 max_iep=max_iep,
433 incremental_nehvi=incremental_nehvi,
434 cache_root=cache_root,
435 marginalize_dim=marginalize_dim,
436 )
437 self.fat = fat
File /scratch/user/ahnafalvi/envs/gptor2/lib/python3.12/site-packages/botorch/utils/multi_objective/hypervolume.py:644, in NoisyExpectedHypervolumeMixin.__init__(self, model, ref_point, X_baseline, sampler, objective, constraints, X_pending, prune_baseline, alpha, cache_pending, max_iep, incremental_nehvi, cache_root, marginalize_dim)
639 # In the case that X_pending is not None, but there are fewer than
640 # max_iep pending points, the box decompositions are not performed in
641 # set_X_pending. Therefore, we need to perform a box decomposition over
642 # f(X_baseline) here.
643 if X_pending is None or X_pending.shape[-2] <= self._max_iep:
--> 644 self._set_cell_bounds(num_new_points=X_baseline.shape[0])
646 # Set q_in=-1 to so that self.sampler is updated at the next forward call.
647 self.q_in = -1
File /scratch/user/ahnafalvi/envs/gptor2/lib/python3.12/site-packages/botorch/utils/multi_objective/hypervolume.py:768, in NoisyExpectedHypervolumeMixin._set_cell_bounds(self, num_new_points)
759 else:
760 # use batched partitioning
761 obj = _pad_batch_pareto_frontier(
762 Y=obj,
763 ref_point=self.ref_point.unsqueeze(0).expand(
(...)
766 feasibility_mask=feas,
767 )
--> 768 self.partitioning = self.p_class(
769 ref_point=self.ref_point, Y=obj, **self.p_kwargs
770 )
771 cell_bounds = self.partitioning.get_hypercell_bounds().to(self.ref_point)
772 cell_bounds = cell_bounds.view(
773 2, *self._batch_sample_shape, *cell_bounds.shape[-2:]
774 ) # 2 x batch_shape x sample_shape x num_cells x m
File /scratch/user/ahnafalvi/envs/gptor2/lib/python3.12/site-packages/botorch/utils/multi_objective/box_decompositions/non_dominated.py:384, in FastNondominatedPartitioning.__init__(self, ref_point, Y)
370 def __init__(
371 self,
372 ref_point: Tensor,
373 Y: Optional[Tensor] = None,
374 ) -> None:
375 """Initialize FastNondominatedPartitioning.
376
377 Args:
(...)
382 >>> bd = FastNondominatedPartitioning(ref_point, Y=Y1)
383 """
--> 384 super().__init__(ref_point=ref_point, Y=Y)
File /scratch/user/ahnafalvi/envs/gptor2/lib/python3.12/site-packages/botorch/utils/multi_objective/box_decompositions/box_decomposition.py:275, in FastPartitioning.__init__(self, ref_point, Y)
265 def __init__(
266 self,
267 ref_point: Tensor,
268 Y: Optional[Tensor] = None,
269 ) -> None:
270 """
271 Args:
272 ref_point: A `m`-dim tensor containing the reference point.
273 Y: A `(batch_shape) x n x m`-dim tensor
274 """
--> 275 super().__init__(ref_point=ref_point, Y=Y, sort=ref_point.shape[-1] == 2)
File /scratch/user/ahnafalvi/envs/gptor2/lib/python3.12/site-packages/botorch/utils/multi_objective/box_decompositions/box_decomposition.py:68, in BoxDecomposition.__init__(self, ref_point, sort, Y)
66 self._validate_inputs()
67 self._neg_pareto_Y = self._compute_pareto_Y()
---> 68 self.partition_space()
69 else:
70 self._neg_Y = None
File /scratch/user/ahnafalvi/envs/gptor2/lib/python3.12/site-packages/botorch/utils/multi_objective/box_decompositions/box_decomposition.py:326, in FastPartitioning.partition_space(self)
324 self._get_single_cell()
325 else:
--> 326 super().partition_space()
File /scratch/user/ahnafalvi/envs/gptor2/lib/python3.12/site-packages/botorch/utils/multi_objective/box_decompositions/box_decomposition.py:152, in BoxDecomposition.partition_space(self)
150 self._partition_space()
151 else:
--> 152 self._partition_space()
File /scratch/user/ahnafalvi/envs/gptor2/lib/python3.12/site-packages/botorch/utils/multi_objective/box_decompositions/box_decomposition.py:337, in FastPartitioning._partition_space(self)
329 r"""Partition the non-dominated space into disjoint hypercells.
330
331 This method supports an arbitrary number of outcomes, but is
332 less efficient than `partition_space_2d` for the 2-outcome case.
333 """
334 if len(self.batch_shape) > 0:
335 # this could be triggered when m=2 outcomes and
336 # BoxDecomposition._partition_space_2d is not overridden.
--> 337 raise NotImplementedError(
338 "_partition_space does not support batch dimensions."
339 )
340 # this assumes minimization
341 # initialize local upper bounds
342 self.register_buffer("_U", self._neg_ref_point.unsqueeze(-2).clone())
NotImplementedError: _partition_space does not support batch dimensions.
Code example
import torch
import math
from botorch.models.multitask import MultiTaskGP
from botorch.acquisition.multi_objective.monte_carlo import qNoisyExpectedHypervolumeImprovement
torch.set_default_dtype(torch.float64)
train_x = torch.linspace(0, 0.95, 10) + 0.05 * torch.rand(10)
train_y1 = torch.sin(train_x * (2 * math.pi)) + torch.randn_like(train_x) * 0.2
train_y2 = torch.cos(train_x * (2 * math.pi)) + torch.randn_like(train_x) * 0.2
train_i_task1 = torch.full_like(train_x, dtype=torch.long, fill_value=0)
train_i_task2 = torch.full_like(train_x, dtype=torch.long, fill_value=1)
full_train_x = torch.cat([train_x, train_x])
full_train_i = torch.cat([train_i_task1, train_i_task2])
full_train_y = torch.cat([train_y1, train_y2])
train_X = torch.stack([full_train_x, full_train_i.type_as(full_train_x)], dim=-1)
train_Y = full_train_y.unsqueeze(-1)
print(train_Y.shape)
# Not specifying output_tasks gives an output for each task
model = MultiTaskGP(train_X, train_Y, task_feature=1)
qNEHVI = qNoisyExpectedHypervolumeImprovement(
model=model,
ref_point=torch.zeros(2),
X_baseline=train_X,
)
test_X = torch.rand(8, 2)  # q=8 x d=2 (no explicit batch dimension)
posterior = model.posterior(test_X)
samples = qNEHVI.get_posterior_samples(posterior)  # default sampler has sample_shape=128, so samples has shape 128 x q=8 x m
acq_vals = qNEHVI(test_X)  # one value per t-batch; a 2-D input is treated as a single t-batch
qNEHVI.sampler.base_samples.shape
Now, if I am not wrong, is this happening because of the task indices? That is,
are both objectives being output in a single one-dimensional array, so that the
acquisition function treats the problem as single-objective because there is
only one output dimension? Or is it something else? How can I get around this
to make MultiTaskGP work with qNEHVI or qEHVI?
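The data layout behind this hypothesis can be illustrated without BoTorch. In the task-index (Hadamard) setup above, `train_Y` has shape `20 x 1`: both objectives are stacked into a single column and distinguished only by the task column of `train_X`. Hypervolume acquisition functions such as qEHVI and qNEHVI instead work with one column per objective (`n x m`). The plain-Python sketch below reshapes stacked rows into that form. `stacked_to_matrix` is a hypothetical helper, not a BoTorch function, and it only succeeds when every task is observed at every input.

```python
from collections import defaultdict


def stacked_to_matrix(xs, tasks, ys, num_tasks):
    """Reshape Hadamard-style (x, task, y) rows into an n x m objective matrix.

    Stacked rows hold one task index and one scalar y each, so all objectives
    share a single column. This regroups them into one column per task.
    Assumes exact float equality for repeated inputs (e.g. x duplicated via
    torch.cat, as in the question) and raises if any task is missing at an x.
    """
    by_x = defaultdict(dict)
    for x, t, y in zip(xs, tasks, ys):
        by_x[x][t] = y
    unique_xs = sorted(by_x)
    matrix = []
    for x in unique_xs:
        missing = [t for t in range(num_tasks) if t not in by_x[x]]
        if missing:
            raise ValueError(f"x={x} has no observation for task(s) {missing}")
        matrix.append([by_x[x][t] for t in range(num_tasks)])
    return unique_xs, matrix


# Two tasks observed at the same two inputs, stacked task by task.
xs = [0.1, 0.5, 0.1, 0.5]
tasks = [0, 0, 1, 1]
ys = [1.0, 2.0, 3.0, 4.0]
unique_xs, Y = stacked_to_matrix(xs, tasks, ys, num_tasks=2)
print(unique_xs)  # [0.1, 0.5]
print(Y)          # [[1.0, 3.0], [2.0, 4.0]]
```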
-
Hi, thank you so much for your response; it's quite insightful. But I was trying to find a use case where MultiTaskGP with task indices is used together with qEHVI or qNEHVI. As far as I understand (I may be wrong), the repo linked above does not make use of task indices. The Matheron multitask GP it uses, to my understanding, inherits from the Kronecker MTGP. In your experience, have you seen any example of an MTGP with task indices (Hadamard) being used with qEHVI? Is it possible to do so, or is that not implemented yet?
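The distinction raised here is about data layout. Kronecker-structured models such as BoTorch's `KroneckerMultiTaskGP` assume all tasks are observed at all inputs, so their training targets form an `n x m` matrix, while a task-index `MultiTaskGP` takes one `(x, task)` pair per row. Converting from the block layout to task-index rows always works, as in the hypothetical plain-Python helper `matrix_to_stacked` below; the reverse is only possible when the data is a full block design.

```python
def matrix_to_stacked(xs, Y):
    """Flatten an n x m (block-design / Kronecker-style) dataset into
    Hadamard-style rows.

    Returns three parallel lists (x, task index, y), ordered task by task,
    matching the torch.cat ordering used in the question's code example.
    """
    num_tasks = len(Y[0])
    out_x, out_task, out_y = [], [], []
    for t in range(num_tasks):
        for x, row in zip(xs, Y):
            out_x.append(x)
            out_task.append(t)
            out_y.append(row[t])
    return out_x, out_task, out_y


stacked = matrix_to_stacked([0.1, 0.5], [[1.0, 3.0], [2.0, 4.0]])
print(stacked)  # ([0.1, 0.5, 0.1, 0.5], [0, 0, 1, 1], [1.0, 2.0, 3.0, 4.0])
```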
-
You would need to use a I tagged this as a bug because the error message should be more informative, and I agree we need better documentation here.
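As an illustration of the kind of early, readable failure mentioned above, here is a hypothetical pre-check (not part of BoTorch) that compares the number of outputs a model returns per point with the length of the reference point before an acquisition function is built. `check_objective_dim` and its message are assumptions made for this sketch.

```python
def check_objective_dim(num_outputs, ref_point):
    """Hypothetical pre-check: fail early with a readable message.

    num_outputs: how many outputs the model returns per point for the
    inputs that will be passed as X_baseline.
    ref_point: one reference value per objective.
    """
    if num_outputs != len(ref_point):
        raise ValueError(
            f"The model returns {num_outputs} output(s) per point, but "
            f"ref_point has {len(ref_point)} entries. Hypervolume-based "
            "acquisition needs one model output per objective."
        )


check_objective_dim(2, [0.0, 0.0])  # passes silently
try:
    check_objective_dim(1, [0.0, 0.0])
except ValueError as err:
    print(err)
```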
-
@esantorella, just so I understand this clearly: if I have a MultiTaskGP with different input data for the two tasks, I could in theory train it and make predictions with it. Later, I could take the means and variances from that MTGP's predictions to create two separate SingleTaskGP models, one per task (using the predicted variance as known noise for each SingleTaskGP), and package them in a ModelListGP to use with qEHVI. Is that a sound strategy? On a different note, you mentioned using an acquisition function that queries on a task-by-task basis; is there any acquisition function implemented like that? Can you mention its name?
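The data step of the strategy described in this question can be sketched in plain Python. Here `predict(x, task)` is a hypothetical stand-in for the fitted MTGP's posterior mean and variance at `x` for one task, and `per_task_datasets` is likewise illustrative. In recent BoTorch versions, each `(X, Y, Yvar)` set could then be converted to tensors and passed to `SingleTaskGP(train_X, train_Y, train_Yvar=...)`, with the per-task models combined via `ModelListGP`.

```python
def per_task_datasets(predict, X_grid, num_tasks):
    """Build one (X, Y, Yvar) training set per task from MTGP predictions.

    predict(x, task) is a hypothetical callable returning (mean, variance),
    standing in for the posterior of a fitted multi-task model.
    """
    datasets = {}
    for t in range(num_tasks):
        means, variances = [], []
        for x in X_grid:
            mean, var = predict(x, t)
            means.append(mean)
            variances.append(var)
        datasets[t] = (list(X_grid), means, variances)
    return datasets


# Toy stand-in for a fitted model: the two tasks have opposite means.
def toy_predict(x, task):
    return (x if task == 0 else -x, 0.01)


data = per_task_datasets(toy_predict, [0.25, 0.5, 1.0], num_tasks=2)
print(data[1])  # ([0.25, 0.5, 1.0], [-0.25, -0.5, -1.0], [0.01, 0.01, 0.01])
```

One consequence of this design: a `ModelListGP` models its outputs independently, so the cross-task correlations learned by the MTGP are not carried into the joint posterior; only what is encoded in the predicted means and variances survives.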
-
Yeah, see the tutorials on Multi-fidelity BO and Multi-fidelity BO with discrete fidelities. Those both use a somewhat complicated setup where you define the cost of evaluating each task and the acquisition function weighs that against the benefits. If you know which task you want to evaluate at, you can skip the cost function and use a simpler setup, as in the discussion here.
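The cost-aware setup described above boils down to weighing the value of an evaluation against its cost. The sketch below is a simplified stand-in, not the tutorials' code (those pair the acquisition function with BoTorch's `InverseCostWeightedUtility` and a cost model, for example): it picks the task with the highest acquisition value per unit cost, assuming hypothetical per-task acquisition values and positive costs are already available. When the task to evaluate is known in advance, this selection step disappears and only that task's acquisition value matters.

```python
def choose_task(acq_values, costs):
    """Pick the task with the best acquisition value per unit cost.

    acq_values: hypothetical acquisition value for evaluating each task.
    costs: hypothetical positive cost of evaluating each task.
    Returns (task index, value-per-cost ratio).
    """
    ratios = [value / cost for value, cost in zip(acq_values, costs)]
    best = max(range(len(ratios)), key=ratios.__getitem__)
    return best, ratios[best]


# Task 1 has a lower acquisition value but is four times cheaper.
print(choose_task([0.8, 0.5], [4.0, 1.0]))  # (1, 0.5)
```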