Skip to content

Commit

Permalink
PyTorch typing for tensor.dtype and tensor.half()
Browse files Browse the repository at this point in the history
Summary: Need this to unblock the next diff.

Reviewed By: fuzic

Differential Revision: D15367860

fbshipit-source-id: b321bb06d7815f49c5c2d158fd2ac1aa7d2c0d56
  • Loading branch information
chandlerzuo authored and facebook-github-bot committed May 21, 2019
1 parent d6e78aa commit eebb8c7
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 8 deletions.
1 change: 0 additions & 1 deletion ax/modelbridge/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __init__(
model: TorchModel,
transforms: List[Type[Transform]],
transform_configs: Optional[Dict[str, TConfig]] = None,
# pyre-fixme[11]: Type `dtype` is not defined.
torch_dtype: Optional[torch.dtype] = None, # noqa T484
torch_device: Optional[torch.device] = None,
status_quo_name: Optional[str] = None,
Expand Down
4 changes: 1 addition & 3 deletions ax/models/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,6 @@ def filter_constraints_and_fixed_features(
feas &= X_np[:, idx] == val
X_feas = X_np[feas, :]
if isinstance(X, torch.Tensor):
return torch.from_numpy(X_feas).to(
device=X.device, dtype=X.dtype # pyre-ignore
)
return torch.from_numpy(X_feas).to(device=X.device, dtype=X.dtype)
else:
return X_feas
2 changes: 1 addition & 1 deletion ax/models/torch/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def fit(
task_features: List[int],
feature_names: List[str],
) -> None:
self.dtype = Xs[0].dtype # pyre-ignore [16]
self.dtype = Xs[0].dtype
self.device = Xs[0].device
self.Xs = Xs
self.Ys = Ys
Expand Down
4 changes: 1 addition & 3 deletions ax/models/torch/botorch_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ def predict_from_model(model: Model, X: Tensor) -> Tuple[Tensor, Tensor]:
mean = posterior.mean.cpu().detach()
# TODO: Allow Posterior to (optionally) return the full covariance matrix
variance = posterior.variance.cpu().detach()
cov = variance.unsqueeze(-1) * torch.eye(
variance.shape[-1], dtype=variance.dtype # pyre-ignore
)
cov = variance.unsqueeze(-1) * torch.eye(variance.shape[-1], dtype=variance.dtype)
return mean, cov


Expand Down

0 comments on commit eebb8c7

Please sign in to comment.