RF, add some device logic

albertz committed Jul 18, 2023
1 parent 88cdf02 commit 252acef
Showing 7 changed files with 63 additions and 8 deletions.
2 changes: 2 additions & 0 deletions returnn/frontend/_backend.py
@@ -653,13 +653,15 @@ def convert_to_tensor(
dims: Sequence[Dim],
dtype: str,
sparse_dim: Optional[Dim] = None,
device: Optional[str] = None,
name: Optional[str] = None,
) -> Tensor[T]:
"""
:param value: tensor, or scalar raw tensor or some other scalar value
:param dims:
:param dtype:
:param sparse_dim:
:param device:
:param name:
:return: tensor
"""
7 changes: 5 additions & 2 deletions returnn/frontend/_utils.py
@@ -79,13 +79,16 @@ def bin_op_out_template(
:return: out, a, b
"""
src_dtype = None
src_device = None
if isinstance(a, Tensor):
src_dtype = a.dtype
src_device = a.device
elif isinstance(b, Tensor):
src_dtype = b.dtype
a = rf.convert_to_tensor(a, dtype=src_dtype, _backend=backend)
src_device = b.device
a = rf.convert_to_tensor(a, dtype=src_dtype, device=src_device, _backend=backend)
src_dtype = src_dtype or a.dtype
b = rf.convert_to_tensor(b, dtype=src_dtype, _backend=backend)
b = rf.convert_to_tensor(b, dtype=src_dtype, device=src_device, _backend=backend)
# sanity checks
# noinspection PyProtectedMember
assert a._raw_backend == b._raw_backend, "Cannot combine tensors from two different frontends, e.g. TF and PT"
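With src_device forwarded here, the operand that gets auto-converted in a binary op is created on the same device as the Tensor operand (subject to the backend's scalar handling further below), rather than on whatever default device is active. A rough sketch of the intended effect; backend selection, the dim name, and the use of random_normal are illustration assumptions, not part of this commit:

import returnn.frontend as rf
from returnn.tensor import Dim

rf.select_backend_torch()  # assumption: PyTorch backend, CUDA available
feat_dim = Dim(8, name="feat")

x = rf.copy_to_device(rf.random_normal([feat_dim]), "cuda")
# 2.0 is auto-converted via convert_to_tensor(..., device=x.device),
# so the combination happens on x's device and the result stays there.
y = x * 2.0
assert y.device == x.device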
6 changes: 5 additions & 1 deletion returnn/frontend/array_.py
@@ -46,6 +46,7 @@ def convert_to_tensor(
dtype: Optional[str] = None,
sparse_dim: Optional[Dim] = None,
shape: Sequence[Dim] = None,
device: Optional[str] = None,
name: Optional[str] = None,
_backend: Optional[Type[Backend]] = None,
) -> Tensor[T]:
@@ -56,6 +57,7 @@ def convert_to_tensor(
:param sparse_dim:
:param shape: alias for dims, for some older code
:param name:
:param device:
:param _backend:
:return: tensor
"""
@@ -100,7 +102,9 @@
]
if dtype is None:
dtype = value_backend.get_dtype_name_raw(value)
return _backend.convert_to_tensor(value=value, dims=dims, dtype=dtype, sparse_dim=sparse_dim, name=name)
return _backend.convert_to_tensor(
value=value, dims=dims, dtype=dtype, sparse_dim=sparse_dim, device=device, name=name
)


constant = convert_to_tensor # alias for some older code
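On the frontend side the new device argument is simply passed through to the backend. A small usage sketch, assuming the PyTorch backend is selected and a CUDA device is available (dim and data are made up for illustration):

import numpy
import returnn.frontend as rf
from returnn.tensor import Dim

rf.select_backend_torch()  # assumption for this sketch
time_dim = Dim(10, name="time")

# Explicit placement: the backend creates the raw tensor directly on CUDA.
t = rf.convert_to_tensor(numpy.arange(10, dtype="float32"), dims=[time_dim], device="cuda")
assert t.device and t.device.startswith("cuda")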
28 changes: 28 additions & 0 deletions returnn/frontend/device.py
@@ -5,9 +5,16 @@

from __future__ import annotations
from typing import Optional
from contextlib import contextmanager
from returnn.tensor import Tensor


__all__ = ["copy_to_device", "get_default_device", "set_default_device_ctx"]


_default_device: Optional[str] = None


def copy_to_device(x: Tensor, device: Optional[str]) -> Tensor:
"""
Copy tensor to device.
@@ -24,3 +31,24 @@ def copy_to_device(x: Tensor, device: Optional[str]) -> Tensor:
return x
# noinspection PyProtectedMember
return x._raw_backend.copy_to_device(x, device)


def get_default_device() -> Optional[str]:
"""
:return: default device, where to put new tensors (via random number generators, constant, range_over_dim, etc)
"""
return _default_device


@contextmanager
def set_default_device_ctx(device: Optional[str]):
"""
:param device: see :func:`get_default_device`
"""
global _default_device
old_device = _default_device
try:
_default_device = device
yield
finally:
_default_device = old_device
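get_default_device and set_default_device_ctx give the frontend a global default placement which the backends consult when creating new raw tensors (see the PyTorch backend changes below). A minimal usage sketch, assuming the PyTorch backend and an available CUDA device:

import returnn.frontend as rf
from returnn.tensor import Dim

rf.select_backend_torch()  # assumption for this sketch
feat_dim = Dim(4, name="feat")

assert rf.get_default_device() is None  # no default set yet

with rf.set_default_device_ctx("cuda"):
    # Tensors created inside the context are allocated on CUDA by default.
    idx = rf.range_over_dim(feat_dim)

# The previous default (None) is restored on exit.
assert rf.get_default_device() is None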
1 change: 1 addition & 0 deletions returnn/tf/frontend_layers/_backend.py
@@ -431,6 +431,7 @@ def convert_to_tensor(
dims: Sequence[Dim],
dtype: str,
sparse_dim: Optional[Dim] = None,
device: Optional[str] = None,
name: Optional[str] = None,
) -> Tensor[Layer]:
"""convert to tensor"""
2 changes: 2 additions & 0 deletions returnn/tf/frontend_low_level/_backend.py
@@ -370,13 +370,15 @@ def convert_to_tensor(
dims: Sequence[Dim],
dtype: str,
sparse_dim: Optional[Dim] = None,
device: Optional[str] = None,
name: Optional[str] = None,
) -> _TT:
"""
:param value:
:param dims:
:param dtype:
:param sparse_dim:
:param device:
:param name:
:return: tensor
"""
25 changes: 20 additions & 5 deletions returnn/torch/frontend/_backend.py
@@ -525,7 +525,11 @@ def create_parameter_raw(tensor: rf.Parameter) -> torch.nn.Parameter:
:return: parameter
"""
assert all(d.is_static() for d in tensor.dims)
data = torch.zeros([d.dimension for d in tensor.dims], dtype=TorchBackend.as_dtype_raw(tensor.dtype))
data = torch.zeros(
[d.dimension for d in tensor.dims],
dtype=TorchBackend.as_dtype_raw(tensor.dtype),
device=rf.get_default_device(),
)
if tensor.dtype.startswith("int"):
requires_grad = False
else:
@@ -645,13 +649,15 @@ def convert_to_tensor(
dims: Sequence[Dim],
dtype: str,
sparse_dim: Optional[Dim] = None,
device: Optional[str] = None,
name: Optional[str] = None,
) -> Tensor[torch.Tensor]:
"""
:param value:
:param dims:
:param dtype:
:param sparse_dim:
:param device:
:param name:
:return: tensor
"""
@@ -661,7 +667,12 @@
name = name or "raw_tensor"
else:
name = name or "const"
value = torch.tensor(value, dtype=TorchBackend.as_dtype_raw(dtype))
value = torch.tensor(
value,
dtype=TorchBackend.as_dtype_raw(dtype),
# Keep scalars on CPU.
device=(device or rf.get_default_device()) if dims else "cpu",
)
assert isinstance(value, torch.Tensor)
return Tensor(name, dims=dims, dtype=dtype, sparse_dim=sparse_dim, raw_tensor=value)
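Note the placement rule: constants with dims follow the explicit device or the frontend default, while scalar constants (empty dims) are deliberately kept on CPU. A small illustration of that intent; the device name and backend selection are assumptions:

import returnn.frontend as rf

rf.select_backend_torch()  # assumption for this sketch

with rf.set_default_device_ctx("cuda"):
    s = rf.constant(3.0, dims=())  # scalar: its raw tensor stays on CPU by design
    # a constant with non-empty dims would instead land on the default device ("cuda")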

@@ -682,7 +693,9 @@ def full(
# onnx::ConstantOfShape (via torch.full) must get shape as int64.
# https://github.com/rwth-i6/returnn/issues/1333#issuecomment-1607236783
shape = [dim.long() if isinstance(dim, torch.Tensor) else dim for dim in shape]
raw_tensor = torch.full(shape, fill_value, dtype=TorchBackend.as_dtype_raw(dtype))
raw_tensor = torch.full(
shape, fill_value, dtype=TorchBackend.as_dtype_raw(dtype), device=rf.get_default_device()
)
return Tensor(
"full", dims=dims, sparse_dim=sparse_dim, feature_dim=feature_dim, dtype=dtype, raw_tensor=raw_tensor
)
@@ -934,7 +947,9 @@ def range_over_dim(dim: Dim, *, dtype: Optional[str] = None) -> Tensor[torch.Ten
sparse_dim=dim,
dtype=dtype,
)
out.raw_tensor = torch.arange(dim.get_dim_value(), dtype=TorchBackend.as_dtype_raw(out.dtype))
out.raw_tensor = torch.arange(
dim.get_dim_value(), dtype=TorchBackend.as_dtype_raw(out.dtype), device=rf.get_default_device()
)
return out

@staticmethod
@@ -1084,7 +1099,7 @@ def random(
out = Tensor(
name=f"random_{distribution}", dims=dims, dtype=dtype, sparse_dim=sparse_dim, feature_dim=feature_dim
)
out.raw_tensor = torch.empty(shape, dtype=dtype_)
out.raw_tensor = torch.empty(shape, dtype=dtype_, device=rf.get_default_device())
assert explicit_state is None # not implemented otherwise
generator = None # using the global default from PT
assert isinstance(static, bool)
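The same default-device lookup was added to create_parameter_raw above, so newly created parameters honor the context as well. A hedged sketch; standalone Parameter creation and the dim names are illustration assumptions:

import returnn.frontend as rf
from returnn.tensor import Dim

rf.select_backend_torch()  # assumption for this sketch
in_dim, out_dim = Dim(8, name="in"), Dim(8, name="out")

with rf.set_default_device_ctx("cuda"):
    # create_parameter_raw now passes rf.get_default_device() to torch.zeros,
    # so the underlying torch.nn.Parameter is allocated on CUDA directly.
    w = rf.Parameter([in_dim, out_dim])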
