Skip to content

Commit

Permalink
Remove as_tensor argument of set_tensors_from_ndarray_1d
Browse files Browse the repository at this point in the history
  • Loading branch information
AVHopp committed Nov 6, 2024
1 parent 3ca48d0 commit 49a65d9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
9 changes: 2 additions & 7 deletions botorch/optim/closures/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def __init__(
closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]],
parameters: dict[str, Tensor],
as_array: Callable[[Tensor], npt.NDArray] = None, # pyre-ignore [9]
as_tensor: Callable[[npt.NDArray], Tensor] = torch.as_tensor,
get_state: Callable[[], npt.NDArray] = None, # pyre-ignore [9]
set_state: Callable[[npt.NDArray], None] = None, # pyre-ignore [9]
fill_value: float = 0.0,
Expand All @@ -99,14 +98,13 @@ def __init__(
Expected to correspond with the first `len(parameters)` optional
gradient tensors returned by `closure`.
as_array: Callable used to convert tensors to ndarrays.
as_tensor: Callable used to convert ndarrays to tensors.
get_state: Callable that returns the closure's state as an ndarray. When
passed as `None`, defaults to calling `get_tensors_as_ndarray_1d`
on `closure.parameters` while passing `as_array` (if given by the user).
set_state: Callable that takes a 1-dimensional ndarray and sets the
closure's state. When passed as `None`, `set_state` defaults to
calling `set_tensors_from_ndarray_1d` with `closure.parameters` and
a given ndarray while passing `as_tensor`.
a given ndarray.
fill_value: Fill value for parameters whose gradients are None. In most
cases, `fill_value` should either be zero or NaN.
persistent: Boolean specifying whether an ndarray should be retained
Expand All @@ -128,15 +126,12 @@ def __init__(
as_array = partial(as_ndarray, dtype=np_float64)

if set_state is None:
set_state = partial(
set_tensors_from_ndarray_1d, parameters, as_tensor=as_tensor
)
set_state = partial(set_tensors_from_ndarray_1d, parameters)

self.closure = closure
self.parameters = parameters

self.as_array = as_ndarray
self.as_tensor = as_tensor
self._get_state = get_state
self._set_state = set_state

Expand Down
8 changes: 6 additions & 2 deletions botorch/optim/utils/numpy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def get_tensors_as_ndarray_1d(
def set_tensors_from_ndarray_1d(
tensors: Iterator[Tensor] | dict[str, Tensor],
array: npt.NDArray,
as_tensor: Callable[[npt.NDArray], Tensor] = torch.as_tensor,
) -> None:
r"""Sets the values of one more tensors based off of a vector of assignments."""
named_tensors_iter = (
Expand All @@ -125,7 +124,12 @@ def set_tensors_from_ndarray_1d(
try:
size = tnsr.numel()
vals = array[index : index + size] if tnsr.ndim else array[index]
tnsr.copy_(as_tensor(vals).to(tnsr).view(tnsr.shape).to(tnsr))
tnsr.copy_(
torch.as_tensor(vals, device=tnsr.device, dtype=tnsr.dtype)
.to(tnsr)
.view(tnsr.shape)
.to(tnsr)
)
index += size
except Exception as e:
raise RuntimeError(
Expand Down

0 comments on commit 49a65d9

Please sign in to comment.