Skip to content

Commit

Permalink
Partially implement to_torch()/from_torch() according to PyTorch in T…
Browse files Browse the repository at this point in the history
…aichi
  • Loading branch information
0xzhang committed Apr 29, 2022
1 parent a8fbe3f commit d48c032
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 2 deletions.
40 changes: 39 additions & 1 deletion python/taichi/lang/field.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import taichi.lang
from taichi._lib import core as _ti_core
from taichi.lang.util import python_scope, to_numpy_type, to_pytorch_type
from taichi.lang.util import python_scope, to_numpy_type, to_pytorch_type, to_paddle_type


class Field:
Expand Down Expand Up @@ -132,6 +132,18 @@ def to_torch(self, device=None):
"""
raise NotImplementedError()

@python_scope
def to_paddle(self, device=None):
"""Converts `self` to a paddle tensor.
Args:
device (paddle.CPUPlace()/CUDAPlace(), optional): The desired device of returned tensor.
Returns:
paddle.Tensor: The result paddle tensor.
"""
raise NotImplementedError()

@python_scope
def from_numpy(self, arr):
"""Loads all elements from a numpy array.
Expand All @@ -154,6 +166,17 @@ def from_torch(self, arr):
"""
self.from_numpy(arr.contiguous())

@python_scope
def from_paddle(self, arr):
"""Loads all elements from a paddle tensor.
The shape of the paddle tensor needs to be the same as `self`.
Args:
arr (paddle.Tensor): The source paddle tensor.
"""
self.from_numpy(arr)

@python_scope
def copy_from(self, other):
"""Copies all elements from another field.
Expand Down Expand Up @@ -267,6 +290,21 @@ def to_torch(self, device=None):
taichi.lang.runtime_ops.sync()
return arr

@python_scope
def to_paddle(self, device=None):
"""Converts this field to a `paddle.Tensor`.
"""
import paddle # pylint: disable=C0415

# pylint: disable=E1101
arr = paddle.zeros(size=self.shape,
dtype=to_paddle_type(self.dtype),
device=device)
from taichi._kernels import tensor_to_ext_arr # pylint: disable=C0415
tensor_to_ext_arr(self, arr)
taichi.lang.runtime_ops.sync()
return arr

@python_scope
def from_numpy(self, arr):
"""Copies the data from a `numpy.ndarray` into this field.
Expand Down
25 changes: 24 additions & 1 deletion python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from taichi.lang.field import Field, ScalarField, SNodeHostAccess
from taichi.lang.swizzle_generator import SwizzleGenerator
from taichi.lang.util import (cook_dtype, in_python_scope, python_scope,
taichi_scope, to_numpy_type, to_pytorch_type,
taichi_scope, to_numpy_type, to_pytorch_type, to_paddle_type,
warning)
from taichi.types import primitive_types
from taichi.types.compound_types import CompoundType
Expand Down Expand Up @@ -1455,6 +1455,29 @@ def to_torch(self, device=None, keep_dims=False):
runtime_ops.sync()
return arr

def to_paddle(self, device=None, keep_dims=False):
"""Converts the field instance to a PaddlePaddle tensor.
Args:
device (paddle.CPUPlace()/CUDAPlace(), optional): The desired device of returned tensor.
keep_dims (bool, optional): Whether to keep the dimension after conversion.
See :meth:`~taichi.lang.field.MatrixField.to_numpy` for more detailed explanation.
Returns:
paddle.Tensor: The result paddle tensor.
"""
import paddle # pylint: disable=C0415
as_vector = self.m == 1 and not keep_dims
shape_ext = (self.n, ) if as_vector else (self.n, self.m)
# pylint: disable=E1101
arr = paddle.empty(self.shape + shape_ext,
dtype=to_paddle_type(self.dtype),
device=device)
from taichi._kernels import matrix_to_ext_arr # pylint: disable=C0415
matrix_to_ext_arr(self, arr, as_vector)
runtime_ops.sync()
return arr

@python_scope
def from_numpy(self, arr):
"""Copies an `numpy.ndarray` into this field.
Expand Down
9 changes: 9 additions & 0 deletions python/taichi/lang/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ def from_torch(self, array_dict):
for k, v in self._items:
v.from_torch(array_dict[k])

@python_scope
def from_paddle(self, array_dict):
for k, v in self._items:
v.from_paddle(array_dict[k])

@python_scope
def to_numpy(self):
return {k: v.to_numpy() for k, v in self._items}
Expand All @@ -179,6 +184,10 @@ def to_numpy(self):
def to_torch(self, device=None):
return {k: v.to_torch(device=device) for k, v in self._items}

@python_scope
def to_paddle(self, device=None):
return {k: v.to_paddle(device=device) for k, v in self._items}

@python_scope
def __len__(self):
return _ti_core.get_num_elements(self.mesh.mesh_ptr, self._type)
Expand Down
27 changes: 27 additions & 0 deletions python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,17 @@ def from_torch(self, array_dict):
for k, v in self._items:
v.from_torch(array_dict[k])

@python_scope
def from_paddle(self, array_dict):
"""Copies the data from a set of `paddle.Tensor` into this field.
The argument `array_dict` must be a dictionay-like object, it
contains all the keys in this field and the copying process
between corresponding items can be performed.
"""
for k, v in self._items:
v.from_paddle(array_dict[k])

@python_scope
def to_numpy(self):
"""Converts the Struct field instance to a dictionary of NumPy arrays.
Expand Down Expand Up @@ -531,6 +542,22 @@ def to_torch(self, device=None):
"""
return {k: v.to_torch(device=device) for k, v in self._items}

@python_scope
def to_paddle(self, device=None):
"""Converts the Struct field instance to a dictionary of PaddlePaddle tensors.
The dictionary may be nested when converting nested structs.
Args:
device (paddle.CPUPlace()/CUDAPlace(), optional): The
desired device of returned tensor.
Returns:
Dict[str, Union[paddle.Tensor, Dict]]: The result
PaddlePaddle tensor.
"""
return {k: v.to_paddle(device=device) for k, v in self._items}

@python_scope
def __setitem__(self, indices, element):
self._initialize_host_accessors()
Expand Down

0 comments on commit d48c032

Please sign in to comment.