[Bug] Unable to fit Heteroskedastic GP with input warping #2551

Closed

SaiAakash opened this issue Sep 24, 2024 · 5 comments
Labels
bug Something isn't working

Comments

@SaiAakash
Contributor

SaiAakash commented Sep 24, 2024

🐛 Bug

Unable to fit a Heteroskedastic GP with a Warp input transform.

To reproduce

**Code snippet to reproduce**

# Define some training data
import numpy as np
import torch
from botorch.models import HeteroskedasticSingleTaskGP
from botorch.fit import fit_gpytorch_mll
from botorch.models.transforms import (
    Normalize,
    Standardize,
    Warp,
    ChainedInputTransform
)
from gpytorch.mlls import ExactMarginalLogLikelihood


# The true function
def oscillator(x):
    return np.cos((x - 5) / 2) ** 2 * x * 2


noise_scale = 3.0

n_data = 200
X_data = np.random.uniform(-10, 10, n_data)
y_data = oscillator(X_data) + np.random.normal(scale=noise_scale, size=X_data.shape)

# add heteroskedastic noise whose scale grows with |x|
y_noise = np.random.normal(scale=noise_scale, size=X_data.shape[0]) * np.abs(
    X_data * 0.5
)
y_data_heteroskedastic = oscillator(X_data) + y_noise

train_X = torch.tensor(X_data).view(-1, 1)
train_y = torch.tensor(y_data_heteroskedastic).view(-1, 1)
train_yvar = torch.tensor(y_noise**2).view(-1, 1)

n = 50
normalize = Normalize(d=1)
warp = Warp(indices=list(range(train_X.shape[1])))
outcome_transform = Standardize(m=1)
input_transform = ChainedInputTransform(**{"normalize": normalize, "warp": warp})


gp = HeteroskedasticSingleTaskGP(
    train_X=train_X[0:n],
    train_Y=train_y[0:n],
    train_Yvar=train_yvar[0:n],
    input_transform=input_transform,
    outcome_transform=outcome_transform,
)

mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)

**Stack trace/error message**

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 15
      6 gp = HeteroskedasticSingleTaskGP(
      7     train_X=train_X[0:n],
      8     train_Y=train_y[0:n],
   (...)
     11     outcome_transform=outcome_transform,
     12 )
     14 mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
---> 15 fit_gpytorch_mll(mll)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/fit.py:105, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
    102 if optimizer is not None:  # defer to per-method defaults
    103     kwargs["optimizer"] = optimizer
--> 105 return FitGPyTorchMLL(
    106     mll,
    107     type(mll.likelihood),
    108     type(mll.model),
    109     closure=closure,
    110     closure_kwargs=closure_kwargs,
    111     optimizer_kwargs=optimizer_kwargs,
    112     **kwargs,
    113 )

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
     91 func = self.__getitem__(types=types)
     92 try:
---> 93     return func(*args, **kwargs)
     94 except MDNotImplementedError:
     95     # Traverses registered methods in order, yields whenever a match is found
     96     funcs = self.dispatch_iter(*types)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/fit.py:205, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, pick_best_of_all_attempts, warning_handler, caught_exception_types, **ignore)
    203 with catch_warnings(record=True) as warning_list, debug(True):
    204     simplefilter("always", category=OptimizationWarning)
--> 205     result = optimizer(mll, closure=closure, **optimizer_kwargs)
    207 # Resolve warnings and determine whether or not to retry
    208 success = True

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/fit.py:94, in fit_gpytorch_mll_scipy(mll, parameters, bounds, closure, closure_kwargs, method, options, callback, timeout_sec)
     91 if closure_kwargs is not None:
     92     closure = partial(closure, **closure_kwargs)
---> 94 result = scipy_minimize(
     95     closure=closure,
     96     parameters=parameters,
     97     bounds=bounds,
     98     method=method,
     99     options=options,
    100     callback=callback,
    101     timeout_sec=timeout_sec,
    102 )
    103 if result.status != OptimizationStatus.SUCCESS:
    104     warn(
    105         f"`scipy_minimize` terminated with status {result.status}, displaying"
    106         f" original message from `scipy.optimize.minimize`: {result.message}",
    107         OptimizationWarning,
    108     )

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/core.py:110, in scipy_minimize(closure, parameters, bounds, callback, x0, method, options, timeout_sec)
    102         result = OptimizationResult(
    103             step=next(call_counter),
    104             fval=float(wrapped_closure(x)[0]),
    105             status=OptimizationStatus.RUNNING,
    106             runtime=monotonic() - start_time,
    107         )
    108         return callback(parameters, result)  # pyre-ignore [29]
--> 110 raw = minimize_with_timeout(
    111     wrapped_closure,
    112     wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
    113     jac=True,
    114     bounds=bounds_np,
    115     method=method,
    116     options=options,
    117     callback=wrapped_callback,
    118     timeout_sec=timeout_sec,
    119 )
    121 # Post-processing and outcome handling
    122 wrapped_closure.state = asarray(raw.x)  # set parameter state to optimal values

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/utils/timeout.py:83, in minimize_with_timeout(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options, timeout_sec)
     81 try:
     82     warnings.filterwarnings("error", message="Method .* cannot handle")
---> 83     return optimize.minimize(
     84         fun=fun,
     85         x0=x0,
     86         args=args,
     87         method=method,
     88         jac=jac,
     89         hess=hess,
     90         hessp=hessp,
     91         bounds=bounds,
     92         constraints=constraints,
     93         tol=tol,
     94         callback=wrapped_callback,
     95         options=options,
     96     )
     97 except OptimizationTimeoutError as e:
     98     msg = f"Optimization timed out after {e.runtime} seconds."

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_minimize.py:731, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    728     res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
    729                              **options)
    730 elif meth == 'l-bfgs-b':
--> 731     res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
    732                            callback=callback, **options)
    733 elif meth == 'tnc':
    734     res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
    735                         **options)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_lbfgsb_py.py:407, in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)
    401 task_str = task.tobytes()
    402 if task_str.startswith(b'FG'):
    403     # The minimization routine wants f and g at the current x.
    404     # Note that interruptions due to maxfun are postponed
    405     # until the completion of the current minimization iteration.
    406     # Overwrite f and g:
--> 407     f, g = func_and_grad(x)
    408 elif task_str.startswith(b'NEW_X'):
    409     # new iteration
    410     n_iterations += 1

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:343, in ScalarFunction.fun_and_grad(self, x)
    341 if not np.array_equal(x, self.x):
    342     self._update_x(x)
--> 343 self._update_fun()
    344 self._update_grad()
    345 return self.f, self.g

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:294, in ScalarFunction._update_fun(self)
    292 def _update_fun(self):
    293     if not self.f_updated:
--> 294         fx = self._wrapped_fun(self.x)
    295         if fx < self._lowest_f:
    296             self._lowest_x = self.x

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:20, in _wrapper_fun.<locals>.wrapped(x)
     16 ncalls[0] += 1
     17 # Send a copy because the user may overwrite it.
     18 # Overwriting results in undefined behaviour because
     19 # fun(self.x) will change self.x, with the two no longer linked.
---> 20 fx = fun(np.copy(x), *args)
     21 # Make sure the function returns a true scalar
     22 if not np.isscalar(fx):

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_optimize.py:79, in MemoizeJac.__call__(self, x, *args)
     77 def __call__(self, x, *args):
     78     """ returns the function value """
---> 79     self._compute_if_needed(x, *args)
     80     return self._value

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_optimize.py:73, in MemoizeJac._compute_if_needed(self, x, *args)
     71 if not np.all(x == self.x) or self._value is None or self.jac is None:
     72     self.x = np.asarray(x).copy()
---> 73     fg = self.fun(x, *args)
     74     self.jac = fg[1]
     75     self._value = fg[0]

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/closures/core.py:162, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
    160         index += size
    161 except RuntimeError as e:
--> 162     value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)
    164 return value, grads

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/utils/common.py:32, in _handle_numerical_errors(error, x, dtype)
     30     _dtype = x.dtype if dtype is None else dtype
     31     return np.full((), "nan", dtype=_dtype), np.full_like(x, "nan", dtype=_dtype)
---> 32 raise error

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/closures/core.py:152, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
    149     self.state = state
    151 try:
--> 152     value_tensor, grad_tensors = self.closure(**kwargs)
    153     value = self.as_array(value_tensor)
    154     grads = self._get_gradient_ndarray(fill_value=self.fill_value)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/closures/core.py:66, in ForwardBackwardClosure.__call__(self, **kwargs)
     64 def __call__(self, **kwargs: Any) -> tuple[Tensor, tuple[Optional[Tensor], ...]]:
     65     with self.context_manager():
---> 66         values = self.forward(**kwargs)
     67         value = values if self.reducer is None else self.reducer(values)
     68         self.backward(value)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/closures/model_closures.py:179, in _get_loss_closure_exact_internal.<locals>.closure(**kwargs)
    177 # The inputs will get transformed in forward here.
    178 model_output = model(*model.train_inputs)
--> 179 log_likelihood = mll(
    180     model_output,
    181     model.train_targets,
    182     # During model training, the model inputs get transformed in the forward
    183     # pass. The train_inputs property is not transformed yet, so we need to
    184     # transform it before passing it to the likelihood for consistency.
    185     *(model.transform_inputs(X=t_in) for t_in in model.train_inputs),
    186     **kwargs,
    187 )
    188 return -log_likelihood

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/module.py:31, in Module.__call__(self, *inputs, **kwargs)
     30 def __call__(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
---> 31     outputs = self.forward(*inputs, **kwargs)
     32     if isinstance(outputs, list):
     33         return [_validate_module_outputs(output) for output in outputs]

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:83, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params, **kwargs)
     81 # Get the log prob of the marginal distribution
     82 res = output.log_prob(target)
---> 83 res = self._add_other_terms(res, params)
     85 # Scale by the amount of data we have
     86 num_data = function_dist.event_shape.numel()

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:42, in ExactMarginalLogLikelihood._add_other_terms(self, res, params)
     39 def _add_other_terms(self, res, params):
     40     # Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models)
     41     for added_loss_term in self.model.added_loss_terms():
---> 42         res = res.add(added_loss_term.loss(*params))
     44     # Add log probs of priors on the (functions of) parameters
     45     res_ndim = res.ndim

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/mlls/noise_model_added_loss_term.py:13, in NoiseModelAddedLossTerm.loss(self, *params)
     12 def loss(self, *params):
---> 13     output = self.noise_mll.model(*params)
     14     targets = self.noise_mll.model.train_targets
     15     return self.noise_mll(output, targets)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/models/exact_gp.py:267, in ExactGP.__call__(self, *args, **kwargs)
    263 if settings.debug.on():
    264     if not all(
    265         torch.equal(train_input, input) for train_input, input in length_safe_zip(train_inputs, inputs)
    266     ):
--> 267         raise RuntimeError("You must train on the training inputs!")
    268 res = super().__call__(*inputs, **kwargs)
    269 return res

RuntimeError: You must train on the training inputs!

Expected Behavior

The model should fit the data without raising an error.

System information

Please complete the following information:

  • BoTorch version: 0.12.0
  • GPyTorch version: 1.13
  • PyTorch version: 2.4.1
  • OS: macOS Sonoma 14.5

Additional context

I suspect this originates from a recent change to HeteroskedasticSingleTaskGP in #2527. It might be because Warp works slightly differently from other input transforms like Normalize. The relevant attributes of Normalize are computed at initialisation and stay fixed throughout training, whereas for learnable transforms like Warp (whose parameters are treated essentially as hyperparameters of the GP), the transform changes dynamically over the training iterations.

The input transform is applied to the train_inputs in the mll closure every time the mll is computed. Since the warp parameters are updated at every step, the transformed train_inputs can diverge from their values at initialisation; with the changes in #2527 the train_inputs are now constantly evolving during training.

I arrived at this hypothesis because the mll is computed successfully once when I call fit_gpytorch_mll and only fails the second time, which I think is because the train_inputs diverge from their original values once the warp parameters have been updated after the first iteration.
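A minimal sketch of this difference (my own illustration, not part of the original report): Warp's Kumaraswamy concentration parameters are registered as learnable torch.nn.Parameters and are therefore updated by the optimizer alongside the GP hyperparameters, whereas Normalize keeps its coefficients in non-learnable buffers, so its output for a fixed input does not change under gradient-based fitting.

# Sketch: why the transformed train inputs drift under Warp but not Normalize.
import torch
from botorch.models.transforms import Normalize, Warp

X = torch.rand(10, 1)

normalize = Normalize(d=1)
warp = Warp(indices=[0])

# Normalize has no learnable parameters (its coefficients live in buffers).
print([name for name, _ in normalize.named_parameters()])  # []
# Warp exposes learnable parameters, optimized jointly with the GP.
print([name for name, _ in warp.named_parameters()])  # concentration0/1

before = warp(X)  # transform is applied (module is in train mode)
with torch.no_grad():
    warp.concentration0.add_(0.5)  # emulate a single optimizer step
after = warp(X)
print(torch.equal(before, after))  # False: transformed inputs have diverged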

@SaiAakash SaiAakash added the bug Something isn't working label Sep 24, 2024
@SaiAakash SaiAakash changed the title [Bug] Unable to fit Heteroskedastic GP when warping inputs [Bug] Unable to fit Heteroskedastic GP with input warping Sep 24, 2024
@Balandat
Contributor

Thanks for flagging this, your explanation makes sense to me. Were you able to verify that this worked properly before the change in #2527?

@SaiAakash
Contributor Author

@Balandat Yes, I was able to fit the model without any problems after undoing the changes from #2527.

@Balandat
Contributor

cc @saitcakmak - I wonder if we need to distinguish between "one-shot learnable" input transforms and those that require joint optimization with the model parameters in settings like this.

At a higher level, it's challenging to handle transforms reliably and robustly in the HeteroskedasticSingleTaskGP setup, since it wraps a model within another model, so we have to deal with transforms in different places. One potential approach might be to apply the reverse of the input warping (assuming it's invertible) within the "outer" GP in the HeteroskedasticSingleTaskGP and then use the same input transform on the "inner" GP?

@saitcakmak
Contributor

When it comes to Heteroskedastic GP, my vote is to remove it from BoTorch. Even if we fix this particular issue, there is still the longstanding issue of its noise model not working correctly: #861

@slishak-PX
Contributor

slishak-PX commented Sep 27, 2024

> When it comes to Heteroskedastic GP, my vote is to remove it from BoTorch. Even if we fix this particular issue, there is still the longstanding issue of its noise model not working correctly: #861

I'd agree with this, or at the very least either raise a warning to the user or add a note to the docstring explaining this.

I was also going to create another issue to point out that, as of 0.12.0, all of the models are supposed to Standardize the outcomes by default, but HeteroskedasticSingleTaskGP does not. Given the discussion above, I thought it better to include it here (although I imagine this omission was intentional).
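For reference, a quick check of the default described above (my own snippet, assuming the 0.12.0 behaviour where an unset outcome transform simply leaves the attribute off the model):

# Compare default outcome transforms in BoTorch 0.12.0.
import torch
from botorch.models import HeteroskedasticSingleTaskGP, SingleTaskGP

X = torch.rand(20, 1, dtype=torch.float64)
Y = torch.randn(20, 1, dtype=torch.float64)
Yvar = torch.full_like(Y, 0.1)

stgp = SingleTaskGP(train_X=X, train_Y=Y)
print(hasattr(stgp, "outcome_transform"))  # True: Standardize attached by default

het = HeteroskedasticSingleTaskGP(train_X=X, train_Y=Y, train_Yvar=Yvar)
print(hasattr(het, "outcome_transform"))  # False: no default standardization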

saitcakmak added a commit to saitcakmak/botorch that referenced this issue Nov 6, 2024
Summary:
This model has been buggy for quite a long time and we still haven't fixed it. Removing it should be preferable to keeping around a known buggy model. Example bug reports:
- pytorch#861
- pytorch#933
- pytorch#2551

Differential Revision: D65543676
saitcakmak added a commit to saitcakmak/Ax that referenced this issue Nov 6, 2024
Summary:
X-link: pytorch/botorch#2616

This model has been buggy for quite a long time and we still haven't fixed it. Removing it should be preferable to keeping around a known buggy model. Example bug reports:
- pytorch/botorch#861
- pytorch/botorch#933
- pytorch/botorch#2551

Differential Revision: D65543676
saitcakmak added a commit to saitcakmak/botorch that referenced this issue Nov 6, 2024
Summary:

This model has been buggy for quite a long time and we still haven't fixed it. Removing it should be preferable to keeping around a known buggy model. Example bug reports:
- pytorch#861
- pytorch#933
- pytorch#2551

Differential Revision: D65543676
saitcakmak added a commit to saitcakmak/botorch that referenced this issue Nov 7, 2024
Summary:

This model has been buggy for quite a long time and we still haven't fixed it. Removing it should be preferable to keeping around a known buggy model. Example bug reports:
- pytorch#861
- pytorch#933
- pytorch#2551

Reviewed By: esantorella

Differential Revision: D65543676
facebook-github-bot pushed a commit that referenced this issue Nov 7, 2024
Summary:
X-link: facebook/Ax#3026

Pull Request resolved: #2616

This model has been buggy for quite a long time and we still haven't fixed it. Removing it should be preferable to keeping around a known buggy model. Example bug reports:
- #861
- #933
- #2551

Reviewed By: esantorella

Differential Revision: D65543676

fbshipit-source-id: e1a9a1e602786c750c7366eae671b92dcbf0f24b