diff --git a/botorch/optim/closures/core.py b/botorch/optim/closures/core.py index 694289d7f5..77e20d5ad0 100644 --- a/botorch/optim/closures/core.py +++ b/botorch/optim/closures/core.py @@ -85,7 +85,6 @@ def __init__( closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]], parameters: dict[str, Tensor], as_array: Callable[[Tensor], npt.NDArray] = None, # pyre-ignore [9] - as_tensor: Callable[[npt.NDArray], Tensor] = torch.as_tensor, get_state: Callable[[], npt.NDArray] = None, # pyre-ignore [9] set_state: Callable[[npt.NDArray], None] = None, # pyre-ignore [9] fill_value: float = 0.0, @@ -99,14 +98,13 @@ def __init__( Expected to correspond with the first `len(parameters)` optional gradient tensors returned by `closure`. as_array: Callable used to convert tensors to ndarrays. - as_tensor: Callable used to convert ndarrays to tensors. get_state: Callable that returns the closure's state as an ndarray. When passed as `None`, defaults to calling `get_tensors_as_ndarray_1d` on `closure.parameters` while passing `as_array` (if given by the user). set_state: Callable that takes a 1-dimensional ndarray and sets the closure's state. When passed as `None`, `set_state` defaults to calling `set_tensors_from_ndarray_1d` with `closure.parameters` and - a given ndarray while passing `as_tensor`. + a given ndarray. fill_value: Fill value for parameters whose gradients are None. In most cases, `fill_value` should either be zero or NaN. persistent: Boolean specifying whether an ndarray should be retained @@ -128,15 +126,12 @@ def __init__( as_array = partial(as_ndarray, dtype=np_float64) if set_state is None: - set_state = partial( - set_tensors_from_ndarray_1d, parameters, as_tensor=as_tensor - ) + set_state = partial(set_tensors_from_ndarray_1d, parameters) self.closure = closure self.parameters = parameters self.as_array = as_ndarray - self.as_tensor = as_tensor self._get_state = get_state self._set_state = set_state diff --git a/botorch/optim/utils/numpy_utils.py b/botorch/optim/utils/numpy_utils.py index bb5d6b9093..bc8fadfe35 100644 --- a/botorch/optim/utils/numpy_utils.py +++ b/botorch/optim/utils/numpy_utils.py @@ -113,7 +113,6 @@ def get_tensors_as_ndarray_1d( def set_tensors_from_ndarray_1d( tensors: Iterator[Tensor] | dict[str, Tensor], array: npt.NDArray, - as_tensor: Callable[[npt.NDArray], Tensor] = torch.as_tensor, ) -> None: r"""Sets the values of one more tensors based off of a vector of assignments.""" named_tensors_iter = ( @@ -125,7 +124,12 @@ def set_tensors_from_ndarray_1d( try: size = tnsr.numel() vals = array[index : index + size] if tnsr.ndim else array[index] - tnsr.copy_(as_tensor(vals).to(tnsr).view(tnsr.shape).to(tnsr)) + tnsr.copy_( + torch.as_tensor(vals, device=tnsr.device, dtype=tnsr.dtype) + .to(tnsr) + .view(tnsr.shape) + .to(tnsr) + ) index += size except Exception as e: raise RuntimeError(