Skip to content

Commit

Permalink
[CustomDevice] add model parallel support for custom device (#52872)
Browse files Browse the repository at this point in the history
  • Loading branch information
ronny1996 authored Apr 14, 2023
1 parent 6b756e8 commit f8d0901
Show file tree
Hide file tree
Showing 8 changed files with 513 additions and 20 deletions.
12 changes: 8 additions & 4 deletions paddle/fluid/distributed/collective/process_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,15 @@ ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() {
return instance;
}

void ProcessGroupIdMap::DestroyProcessGroup(int gid) {
int use_count = ProcessGroupIdMap::GetInstance()[gid].use_count();
for (int i = 0; i < use_count; ++i) {
ProcessGroupIdMap::GetInstance()[gid].reset();
void ProcessGroupIdMap::DestroyProcessGroup() {
auto& id_map = ProcessGroupIdMap::GetInstance();
for (auto iter = id_map.begin(); iter != id_map.end(); ++iter) {
auto use_count = iter->second.use_count();
for (int i = 0; i < use_count; ++i) {
iter->second.reset();
}
}
id_map.clear();
}

} // namespace distributed
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/collective/process_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ class ProcessGroupIdMap
: public std::unordered_map<int, std::shared_ptr<ProcessGroup>> {
public:
static ProcessGroupIdMap& GetInstance();
static void DestroyProcessGroup(int gid);
static void DestroyProcessGroup();
};

// TODO(dev): The following method will be removed soon.
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ endif()
copy_if_different(${pybind_file} ${pybind_file_final})

if (WITH_CUSTOM_DEVICE)
cc_library(custom_device_common_op_registry SRCS custom_device_common_op_registry.cc DEPS operator)
cc_library(custom_device_common_op_registry SRCS custom_device_common_op_registry.cc DEPS operator phi_api)
endif()

if(NOT "${OP_LIST}" STREQUAL "")
Expand Down
502 changes: 502 additions & 0 deletions paddle/fluid/operators/custom_device_common_op_registry.cc

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion paddle/fluid/pybind/distributed_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,6 @@ void BindDistributed(py::module *m) {
*m, "ProcessGroupIdMap")
.def_static("destroy",
distributed::ProcessGroupIdMap::DestroyProcessGroup,
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>());
}

Expand Down
3 changes: 0 additions & 3 deletions python/paddle/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from .collective import split # noqa: F401
from .collective import new_group # noqa: F401
from .collective import is_available
from .collective import _destroy_process_group_id_map
from .communication import (
stream,
ReduceOp,
Expand Down Expand Up @@ -122,5 +121,3 @@
"is_available",
"get_backend",
]

atexit.register(_destroy_process_group_id_map)
10 changes: 0 additions & 10 deletions python/paddle/distributed/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,16 +172,6 @@ def _set_custom_gid(gid):
_custom_gid = gid


def _destroy_process_group_id_map():
"""
Destroy the custom process group. Designed for CustomDevice.
"""
core.ProcessGroupIdMap.destroy()


def new_group(ranks=None, backend=None, timeout=_default_timeout):
"""
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,4 @@ def remove_flag_if_exists(name):
# Keep clear_kernel_factory running before clear_device_manager
atexit.register(core.clear_device_manager)
atexit.register(core.clear_kernel_factory)
atexit.register(core.ProcessGroupIdMap.destroy)

0 comments on commit f8d0901

Please sign in to comment.