[Paddle Tensor No.2] Add Tensor.__rfloordiv__, reusing the existing Tensor.__floordiv__ interface (#69222)

* fix

* add dygraph test

* add dygraph test
enkilee authored Nov 12, 2024
1 parent d9c4b9f commit c8a1c14
Showing 7 changed files with 130 additions and 3 deletions.
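For context, the change wires up the reflected floor-division protocol for paddle.Tensor, so an expression with the tensor on the right-hand side of // can dispatch to Tensor.__rfloordiv__. A minimal dygraph sketch of the intended usage follows; the values and the commented results are illustrative expectations, not output taken from this PR.

import numpy as np
import paddle

paddle.disable_static()  # dygraph mode

t = paddle.to_tensor([1, 2, 3], dtype="int64")

# int.__floordiv__(Tensor) returns NotImplemented, so Python falls back to
# Tensor.__rfloordiv__, i.e. 7 // t element-wise.
print(7 // t)                   # expected: [7, 3, 2]

# A NumPy array on the left also reaches __rfloordiv__ when called explicitly,
# which is what the new dygraph test below exercises.
np_lhs = np.array([10, 20, 30])
print(t.__rfloordiv__(np_lhs))  # expected: np_lhs // t = [10, 10, 10]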
94 changes: 94 additions & 0 deletions paddle/fluid/pybind/eager_math_op_patch.cc
@@ -1558,6 +1558,96 @@ static PyObject* tensor__floordiv__method(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor__rfloordiv__method(TensorObject* self,
                                           PyObject* args,
                                           PyObject* kwargs) {
  phi::RecordEvent pythonc_record_event(
      "__rfloordiv__ pybind_patch_func", phi::TracerEventType::UserDefined, 1);
  EAGER_TRY
  VLOG(6) << "Running Eager tensor__rfloordiv__method";

  // Set Device ID
  auto place = egr::Controller::Instance().GetExpectedPlace();
  SetDevice(place);

  paddle::Tensor ret;
  paddle::Tensor self_tensor = self->tensor;

  PyObject* other_obj = PyTuple_GET_ITEM(args, 0);

  // 1. scalar cases
  // there is no scalar result path for rfloordiv, but self_tensor may still
  // need to be cast when the other operand is a Python scalar.
  if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) ||
      IsNumpyType(other_obj)) {
    if (PyFloat_Check(other_obj)) {
      if (_supported_int_dtype_.find(self_tensor.dtype()) !=
          _supported_int_dtype_.end()) {
        eager_gil_scoped_release guard;
        self_tensor = cast_ad_func(self_tensor, DataType::FLOAT32);
      }
    } else if (PyCheckInteger(other_obj) &&
               self_tensor.dtype() == DataType::BOOL) {
      eager_gil_scoped_release guard;
      self_tensor = cast_ad_func(self_tensor, DataType::INT64);
    }
  } else if (PyComplex_Check(other_obj)) {
    if (is_support_complex(self_tensor.dtype()) == false) {
      eager_gil_scoped_release guard;
      self_tensor = cast_ad_func(
          self_tensor, promoteTypes(self_tensor.dtype(), DataType::COMPLEX64));
    }
  }

  // 2. create or get tensor for other_obj
  paddle::Tensor other_tensor;
  if (PyCheckTensor(other_obj)) {
    auto& self_tensor_ref_addr = self->tensor;
    auto& other_tensor_ref_addr = CastPyArg2Tensor(other_obj, 0);
    const phi::distributed::ProcessMesh* mesh = nullptr;
    if (InputsContainDistTensor(
            &mesh, self_tensor_ref_addr, other_tensor_ref_addr)) {
      ConvertAllInputsToDistTensor(
          mesh, self_tensor_ref_addr, other_tensor_ref_addr);
    }
    self_tensor = self_tensor_ref_addr;
    other_tensor = other_tensor_ref_addr;
  } else {
    if (IsNumpyArray(other_obj)) {
      py::object numpy_value =
          py::reinterpret_borrow<py::object>(py::handle(other_obj));
      other_tensor = paddle::empty({}, phi::DataType::FLOAT32, place);
      InitTensorWithNumpyValue(numpy_value, place, &other_tensor);
    } else {
      paddle::experimental::Scalar value =
          CastPyArg2Scalar(other_obj, "__rfloordiv__", 0);
      if (PyComplex_Check(other_obj)) {
        eager_gil_scoped_release guard;
        other_tensor =
            full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
      } else {
        eager_gil_scoped_release guard;
        other_tensor =
            full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place());
      }
    }
    const phi::distributed::ProcessMesh* mesh = nullptr;
    if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) {
      ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor);
    }
  }

  // 3. calculation
  VLOG(6) << "Calling floor_divide_ad_func in tensor__rfloordiv__method";
  {
    eager_gil_scoped_release guard;
    ret = floor_divide_ad_func(other_tensor, self_tensor);
  }

  return ToPyObject(ret);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor__pow__method(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
@@ -1973,6 +2063,10 @@ PyMethodDef math_op_patch_methods[] = { // NOLINT
     (PyCFunction)(void (*)())tensor__floordiv__method,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"__rfloordiv__",
     (PyCFunction)(void (*)())tensor__rfloordiv__method,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"__pow__",
     (PyCFunction)(void (*)())tensor__pow__method,
     METH_VARARGS | METH_KEYWORDS,
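As a quick cross-check of the operand order implemented above — tensor__rfloordiv__method calls floor_divide_ad_func(other_tensor, self_tensor) — the following dygraph sketch, with illustrative values not taken from the PR, verifies that t.__rfloordiv__(x) matches an explicit paddle.floor_divide with the operands swapped:

import numpy as np
import paddle

t = paddle.to_tensor([[4, 9], [2, 5]], dtype="int64")
x = np.array([[8, 3], [7, 14]], dtype=np.int64)

out_reflected = t.__rfloordiv__(x)                         # computes x // t
out_explicit = paddle.floor_divide(paddle.to_tensor(x), t)

np.testing.assert_array_equal(out_reflected.numpy(), out_explicit.numpy())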
1 change: 1 addition & 0 deletions python/paddle/base/dygraph/math_op_patch.py
@@ -238,6 +238,7 @@ def _mT_(var: Tensor) -> Tensor:
'__lt__',
'__le__',
'__floordiv__',
'__rfloordiv__',
'__pow__',
'__rpow__',
'__eq__',
8 changes: 8 additions & 0 deletions python/paddle/base/layers/math_op_patch.py
@@ -54,6 +54,7 @@
"__truediv__",
"__rtruediv__",
"__floordiv__",
"__rfloordiv__",
"__pow__",
"__rpow__",
"__eq__",
@@ -78,6 +79,7 @@
"__pow__": "A ** B",
"__rpow__": "A **= B",
"__floordiv__": "A //B",
"__rfloordiv__": "A //=B",
"__mod__": "A % B",
"__rmod__": "A %= B",
"__matmul__": "A @ B",
@@ -860,6 +862,12 @@ def to_dense(var):
'__floordiv__', 'elementwise_floordiv', False, None
),
),
(
'__rfloordiv__',
_binary_creator_(
'__rfloordiv__', 'elementwise_floordiv', True, None
),
),
(
'__mod__',
_binary_creator_('__mod__', 'elementwise_mod', False, None),
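The True passed to _binary_creator_ above is the reverse flag: the reversed dunder emits the same elementwise_floordiv op with the operand roles swapped. The stand-alone sketch below only illustrates that pattern; it is a hypothetical helper, not Paddle's actual _binary_creator_.

def make_binary_method(op, reverse=False):
    # Hypothetical helper: a reversed dunder is simply the forward op with the
    # operand roles swapped.
    def method(self, other):
        lhs, rhs = (other, self) if reverse else (self, other)
        return op(lhs, rhs)
    return method

floordiv = make_binary_method(lambda a, b: a // b, reverse=False)
rfloordiv = make_binary_method(lambda a, b: a // b, reverse=True)

assert floordiv(7, 2) == 3    # 7 // 2
assert rfloordiv(2, 7) == 3   # reversed: other // self = 7 // 2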
7 changes: 7 additions & 0 deletions python/paddle/pir/math_op_patch.py
@@ -50,6 +50,7 @@
"__truediv__",
"__rtruediv__",
"__floordiv__",
"__rfloordiv__",
"__pow__",
"__rpow__",
"__eq__",
@@ -1092,6 +1093,12 @@ def register_hook(self, hook):
'__floordiv__', paddle.tensor.floor_divide, False, None
),
),
(
'__rfloordiv__',
_binary_creator_(
'__rfloordiv__', paddle.tensor.floor_divide, True, None
),
),
(
'__mod__',
_binary_creator_('__mod__', paddle.tensor.remainder, False, None),
3 changes: 2 additions & 1 deletion python/paddle/tensor/tensor.prototype.pyi
@@ -178,7 +178,8 @@ class AbstractTensor:
def __rtruediv__(self, y: _typing.TensorLike) -> Tensor: ... # type: ignore
def __rmod__(self, y: _typing.TensorLike) -> Tensor: ... # type: ignore
def __rpow__(self, y: _typing.TensorLike) -> Tensor: ... # type: ignore
def __rdiv__(self, y: _typing.TensorLike) -> Tensor: ...
def __rdiv__(self, y: _typing.TensorLike) -> Tensor: ... # type: ignore
def __rfloordiv__(self, y: _typing.TensorLike) -> Tensor: ... # type: ignore

# type cast
def __bool__(self) -> bool: ...
13 changes: 13 additions & 0 deletions test/legacy_test/test_math_op_patch.py
@@ -484,6 +484,19 @@ def test_dygraph_floor_divide(self):
        np.testing.assert_equal(actual_out, expect_out)
        paddle.enable_static()

    def test_dygraph_rfloordiv(self):
        paddle.disable_static()
        np_a = np.random.random((2, 3, 4)).astype(np.int32)
        np_b = np.random.random((2, 3, 4)).astype(np.int32)
        np_b[np.abs(np_b) < 1] = 2
        # normal case: nparray // tensor
        tensor_a = paddle.to_tensor(np_a, dtype="int32")
        tensor_b = paddle.to_tensor(np_b, dtype="int32")
        expect_out = np_a // np_b
        actual_out = tensor_b.__rfloordiv__(np_a)
        np.testing.assert_equal(actual_out, expect_out)
        paddle.enable_static()

    def test_dygraph_elementwise_pow(self):
        paddle.disable_static()
        self.init_data()
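The new test covers the integer-array case; the eager patch above additionally casts an integer tensor to float32 when the left operand is a Python float. A hedged sketch of that expected behavior — the result values and dtype are expectations inferred from the patch, not asserted by the PR's tests:

import paddle

t = paddle.to_tensor([2, 3, 4], dtype="int64")
out = 7.0 // t   # a float LHS dispatches to t.__rfloordiv__(7.0)
print(out)       # expected values: [3., 2., 1.]
print(out.dtype) # expected: paddle.float32 (int tensor cast for the float scalar)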
7 changes: 5 additions & 2 deletions test/legacy_test/test_math_op_patch_pir.py
@@ -143,6 +143,7 @@ def test_floordiv(self):
paddle.to_tensor(x_np), paddle.to_tensor(y_np)
)
res_np_d = x_np.__floordiv__(y_np)
res_np_e = x_np.__rfloordiv__(y_np)
paddle.enable_static()
with paddle.pir_utils.IrGuard():
main_program, exe, program_guard = new_program()
@@ -156,14 +157,16 @@
b = x // y
c = x.floor_divide(y)
d = x.__floordiv__(y)
(b_np, c_np, d_np) = exe.run(
e = x.__rfloordiv__(y)
(b_np, c_np, d_np, e_np) = exe.run(
main_program,
feed={"x": x_np, "y": y_np},
fetch_list=[b, c, d],
fetch_list=[b, c, d, e],
)
np.testing.assert_allclose(res_np_b, b_np, atol=1e-05)
np.testing.assert_allclose(res_np_c, c_np, atol=1e-05)
np.testing.assert_allclose(res_np_d, d_np, atol=1e-05)
np.testing.assert_allclose(res_np_e, e_np, rtol=1e-05)

def test_bitwise_not(self):
paddle.disable_static()
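For completeness, a minimal static-graph sketch of the reflected operator, loosely mirroring the test above. The program/executor boilerplate and values are illustrative; whether the legacy static-graph path or the PIR path handles the dunder depends on the Paddle build, and both are patched in this PR.

import numpy as np
import paddle

paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name="x", shape=[3], dtype="int64")
    y = paddle.static.data(name="y", shape=[3], dtype="int64")
    e = x.__rfloordiv__(y)  # equivalent to y // x

exe = paddle.static.Executor(paddle.CPUPlace())
x_np = np.array([2, 3, 4], dtype=np.int64)
y_np = np.array([9, 9, 9], dtype=np.int64)
(e_np,) = exe.run(main_program, feed={"x": x_np, "y": y_np}, fetch_list=[e])
np.testing.assert_array_equal(e_np, y_np // x_np)  # expected [4, 3, 2]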
