Skip to content

Commit

Permalink
[CustomDevice] add eager mode support (#42034)
Browse files Browse the repository at this point in the history
  • Loading branch information
ronny1996 authored Apr 24, 2022
1 parent 0e0f7da commit ccafd2e
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
5 changes: 4 additions & 1 deletion paddle/fluid/pybind/eager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,13 @@ void InitTensorWithNumpyValue(TensorObject* self, const py::object& array,
zero_copy);
} else if (platform::is_npu_place(place)) {
SetTensorFromPyArray<platform::NPUPlace>(impl_ptr, array, place, zero_copy);
} else if (platform::is_custom_place(place)) {
SetTensorFromPyArray<platform::CustomPlace>(impl_ptr, array, place,
zero_copy);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Place should be one of "
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace"));
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace/CustomPlace"));
}
}

Expand Down
8 changes: 7 additions & 1 deletion paddle/fluid/pybind/eager_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ extern PyTypeObject* g_cpuplace_pytype;
extern PyTypeObject* g_xpuplace_pytype;
extern PyTypeObject* g_npuplace_pytype;
extern PyTypeObject* g_cudapinnedplace_pytype;
extern PyTypeObject* g_customplace_pytype;
extern PyTypeObject* g_framework_tensor_pytype;
extern PyTypeObject* g_framework_lodtensorarray_pytype;
extern PyTypeObject* g_custom_op_kernel_ctx_pytype;
Expand Down Expand Up @@ -377,10 +378,15 @@ platform::Place CastPyArg2Place(PyObject* obj, ssize_t arg_pos) {
} else if (PyObject_IsInstance(
obj, reinterpret_cast<PyObject*>(g_cudapinnedplace_pytype))) {
place = ::pybind11::handle(obj).cast<platform::CUDAPinnedPlace>();
} else if (PyObject_IsInstance(
obj, reinterpret_cast<PyObject*>(g_customplace_pytype))) {
place = ::pybind11::handle(obj).cast<platform::CustomPlace>();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
"one of(Place,CUDAPlace,CPUPlace,XPUPlace,NPUPlace,CUDAPinnedPlace), "
"one "
"of(Place,CUDAPlace,CPUPlace,XPUPlace,NPUPlace,CUDAPinnedPlace,"
"CustomPlace), "
"but got %s",
arg_pos + 1, reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
}
Expand Down
9 changes: 6 additions & 3 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ PyTypeObject *g_xpuplace_pytype = nullptr;
PyTypeObject *g_npuplace_pytype = nullptr;
PyTypeObject *g_cudapinnedplace_pytype = nullptr;
PyTypeObject *g_mluplace_pytype = nullptr;
PyTypeObject *g_customplace_pytype = nullptr;
PyTypeObject *g_framework_tensor_pytype = nullptr;
PyTypeObject *g_framework_lodtensorarray_pytype = nullptr;
PyTypeObject *g_custom_op_kernel_ctx_pytype = nullptr;
Expand Down Expand Up @@ -2125,8 +2126,8 @@ All parameter, weight, gradient are variables in Paddle.
#endif
return devices;
});
py::class_<platform::CustomPlace>(m, "CustomPlace",
R"DOC(
py::class_<platform::CustomPlace> customplace(m, "CustomPlace",
R"DOC(
CustomPlace is a descriptor of a device.
It represents a custom device on which a tensor will be allocated and a model will run.
Expand All @@ -2135,7 +2136,9 @@ All parameter, weight, gradient are variables in Paddle.
import paddle
fake_cpu_place = paddle.CustomPlace("FakeCPU", 0)
)DOC")
)DOC");
g_customplace_pytype = reinterpret_cast<PyTypeObject *>(customplace.ptr());
customplace
.def("__init__",
[](platform::CustomPlace &self, const std::string &device_type,
int dev_id) {
Expand Down

0 comments on commit ccafd2e

Please sign in to comment.