Commit
【Hackathon 5th No.124】Support r to x on 1D one-device-to-multiple mesh (#60281)

* wip: reshard r2x
* fix: retrieve util func
* style: reformat code
* fix: cannot build
Showing 11 changed files with 347 additions and 34 deletions.
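The user-facing effect of the new RToXExpandReshardFunction is that a tensor replicated on a one-device mesh can now be resharded onto a multi-device mesh. Below is a minimal sketch of that behavior, using only the API exercised by the new test further down and assuming a two-process (e.g. two-GPU) launch; it is an illustration, not part of the commit:

import paddle
import paddle.distributed as dist

# One-device source mesh and two-device destination mesh, as in the test.
in_mesh = dist.ProcessMesh([0], dim_names=["x"])
out_mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

a = paddle.ones((10, 20))
replicated = dist.shard_tensor(a, in_mesh, [dist.Replicate()])

# r -> s: the replicated value is split along dim 0 across the two ranks.
sharded = dist.reshard(replicated, out_mesh, [dist.Shard(0)])

# r -> p: one rank keeps the value, the other holds zeros (sum-partial).
partial = dist.reshard(
    replicated, out_mesh, [dist.Partial(dist.ReduceType.kRedSum)]
)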
138 changes: 138 additions & 0 deletions
paddle/phi/core/distributed/auto_parallel/reshard/r_to_x_reshard_function.cc
@@ -0,0 +1,138 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/distributed/auto_parallel/reshard/r_to_x_reshard_function.h"

#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
#include "paddle/phi/core/distributed/store/store_utils.h"
#include "paddle/phi/kernels/add_n_kernel.h"
#include "paddle/phi/kernels/concat_kernel.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/p_recv_kernel.h"
#include "paddle/phi/kernels/p_send_kernel.h"
#include "paddle/phi/kernels/split_kernel.h"

namespace phi {
namespace distributed {

bool RToXExpandReshardFunction::IsSuitable(
    const DistTensor& in, const TensorDistAttr& out_dist_attr) {
  const auto& in_dist_attr = in.dist_attr();

  RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_replicated());

  const auto& in_process_mesh = in_dist_attr.process_mesh();
  const auto& out_process_mesh = out_dist_attr.process_mesh();

  // Only handle 1-D meshes where the input lives on a single device and the
  // output mesh spans more than one device.
  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1);
  RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1);
  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.process_ids().size() == 1);
  RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.process_ids().size() != 1);

  return true;
}

void RToXExpandReshardFunction::Eval(phi::DeviceContext* dev_ctx,
                                     const DistTensor& in,
                                     const TensorDistAttr& out_dist_attr,
                                     DistTensor* out) {
  VLOG(3) << "Call RToXExpandReshardFunction Eval";
  const auto& in_dist_attr = in.dist_attr();
  const auto& out_dims_mapping = out_dist_attr.dims_mapping();
  const auto& in_mesh = in_dist_attr.process_mesh();
  const auto& out_mesh = out_dist_attr.process_mesh();
  const auto& in_process_ids = in_mesh.process_ids();
  const auto& out_process_ids = out_mesh.process_ids();
  int64_t cur_global_rank = GetCurGlobalRank();
  int64_t root_rank = in_process_ids[0];
  auto all_process_ids = GetUnionProcessIds(in_process_ids, out_process_ids);
  bool dynamic_shape = true;
  auto dtype = in.dtype();
  const auto& out_partial_status = out_dist_attr.partial_status();
  bool cur_rank_in_out_mesh =
      (std::find(out_process_ids.begin(),
                 out_process_ids.end(),
                 cur_global_rank) != out_process_ids.end());
  DenseTensor result_value;

  if (root_rank == cur_global_rank) {
    // The single source rank sends the replicated value to every other rank
    // in the output mesh.
    for (size_t i = 0; i < out_process_ids.size(); ++i) {
      if (out_process_ids[i] != root_rank) {
        RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
                                  PSendKernel,
                                  dtype,
                                  all_process_ids,
                                  in.value(),
                                  out_process_ids[i],
                                  dynamic_shape);
      }
    }
    if (cur_rank_in_out_mesh) {
      result_value = in.value();
    }
  } else {
    // Non-root ranks receive the full replicated value from the root.
    RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
                              PRecv,
                              dtype,
                              all_process_ids,
                              root_rank,
                              dynamic_shape,
                              &result_value);
  }

  if (cur_rank_in_out_mesh) {
    if (out_dist_attr.is_partial()) {
      // For a sum-partial output, only the first rank keeps the real value;
      // the rest hold zeros so the partial sum reproduces the original tensor.
      auto out_reduce_type = out_partial_status.at(0);
      if (out_reduce_type == ReduceType::kRedSum &&
          cur_global_rank != out_process_ids[0]) {
        IntArray shape(result_value.dims().Get(), result_value.dims().size());
        RESHARD_FUNCTOR(dev_ctx, Full, dtype, shape, 0, &result_value);
      }
      SetValue(out, result_value);
    } else if (out_dist_attr.is_shard()) {
      std::map<int, int64_t> split_axis_to_mesh_axis =
          GetSplitAxisWithDimsMapping(out_dims_mapping);
      std::vector<int64_t> coord_in_mesh = GetCurRankCoordInMesh(out_mesh);

      int split_axis = split_axis_to_mesh_axis.begin()->first;
      int64_t mesh_axis = split_axis_to_mesh_axis.begin()->second;
      int64_t num_of_process = out_mesh.shape()[mesh_axis];

      std::vector<int64_t> split_num_vec =
          BalancedSplit(in.dims()[split_axis], num_of_process);
      IntArray sections(split_num_vec);

      std::vector<DenseTensor> split_out_vec;
      RESHARD_FUNCTOR(dev_ctx,
                      Split,
                      dtype,
                      result_value,
                      sections,
                      split_axis,
                      &split_out_vec);

      // Keep only the slice that matches this rank's coordinate in the
      // output mesh.
      SetValue(out, split_out_vec[coord_in_mesh[mesh_axis]]);
    } else {
      SetValue(out, result_value);
    }
    SetDistProps(out, in.dims(), out_dist_attr);
  }
}

}  // namespace distributed
}  // namespace phi
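For the shard branch of Eval above, the received replicated value is split along the sharded dimension and each rank keeps the section matching its coordinate in the output mesh. Here is a small NumPy sketch of that step; the remainder rule of balanced_split below (larger chunks on the earlier ranks) is an assumption that mirrors what the new test expects for odd sizes, not a quote of Paddle's BalancedSplit:

import numpy as np

def balanced_split(total, num_of_process):
    # Assumed rule: hand the remainder to the first ranks.
    base, rem = divmod(total, num_of_process)
    return [base + 1 if i < rem else base for i in range(num_of_process)]

value = np.ones((7, 4))   # replicated value every out-mesh rank now holds
split_axis = 0            # from GetSplitAxisWithDimsMapping(out_dims_mapping)
num_of_process = 2        # out_mesh.shape()[mesh_axis]

sections = balanced_split(value.shape[split_axis], num_of_process)  # [4, 3]
pieces = np.split(value, np.cumsum(sections)[:-1], axis=split_axis)

coord_in_mesh = 1              # this rank's coordinate along the mesh axis
local_value = pieces[coord_in_mesh]
print(local_value.shape)       # (3, 4) on rank 1, (4, 4) on rank 0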
36 changes: 36 additions & 0 deletions
paddle/phi/core/distributed/auto_parallel/reshard/r_to_x_reshard_function.h
@@ -0,0 +1,36 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.h"

namespace phi {
namespace distributed {

class RToXExpandReshardFunction final : public ReshardFunction {
 public:
  bool IsSuitable(const DistTensor& in,
                  const TensorDistAttr& out_dist_attr) override;

  void Eval(DeviceContext* dev_ctx,
            const DistTensor& in,
            const TensorDistAttr& out_dist_attr,
            DistTensor* out) override;

  std::string Name() override { return "RToXExpandReshard"; }
};

}  // namespace distributed
}  // namespace phi
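The header only declares the IsSuitable/Eval/Name triple that every ReshardFunction exposes. As a rough schematic sketch (all names below are illustrative, not Paddle source), the reshard machinery can be thought of as probing each registered function with IsSuitable and running Eval on the first one that accepts the input/output dist attrs:

# Schematic dispatch sketch; classes and names here are hypothetical.
class ReshardFunctionLike:
    def is_suitable(self, in_tensor, out_dist_attr) -> bool:
        raise NotImplementedError

    def eval(self, dev_ctx, in_tensor, out_dist_attr, out):
        raise NotImplementedError


def choose_and_run(candidates, dev_ctx, in_tensor, out_dist_attr, out):
    # Try each registered reshard function; the first IsSuitable match wins.
    for func in candidates:
        if func.is_suitable(in_tensor, out_dist_attr):
            return func.eval(dev_ctx, in_tensor, out_dist_attr, out)
    raise RuntimeError("no suitable reshard function for this case")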
@@ -0,0 +1,104 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

import paddle
import paddle.distributed as dist
from paddle.framework import core


class TestReshardRToX:
    def __init__(self):
        self._shape = eval(os.getenv("shape"))
        self._dtype = os.getenv("dtype")
        self._seeds = eval(os.getenv("seeds"))
        self._shard = eval(os.getenv("shard"))
        self._backend = os.getenv("backend")
        self._in_mesh = dist.ProcessMesh([0], dim_names=["x"])
        self._out_mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

    def _set_place(self):
        if self._backend == "cpu":
            paddle.set_device("cpu")
            place = paddle.CPUPlace()
        elif self._backend == "gpu":
            place = paddle.CUDAPlace(dist.get_rank())
        dev_ctx = core.DeviceContext.create(place)

    def test_r_to_s(self):
        self._set_place()

        a = paddle.ones(self._shape)
        input_tensor = dist.shard_tensor(a, self._in_mesh, [dist.Replicate()])
        out = dist.reshard(
            input_tensor, self._out_mesh, [dist.Shard(self._shard)]
        )

        out_shape = list(self._shape)
        if out_shape[self._shard] % 2 == 0:
            out_shape[self._shard] = out_shape[self._shard] // 2
            np.testing.assert_equal(out.numpy(), a.numpy())
        else:
            out_shape[self._shard] = (
                out_shape[self._shard] // 2
                if dist.get_rank() == 1
                else out_shape[self._shard] // 2 + 1
            )
        assert np.equal(out.shape, input_tensor.shape).all()
        assert np.equal(out._local_shape, out_shape).all()

    def test_r_to_r(self):
        self._set_place()

        a = paddle.ones(self._shape)
        input_tensor = dist.shard_tensor(a, self._in_mesh, [dist.Replicate()])
        out = dist.reshard(input_tensor, self._out_mesh, [dist.Replicate()])

        if dist.get_rank() == 0:
            assert np.equal(out.shape, input_tensor.shape).all()
            np.testing.assert_equal(out._local_value().numpy(), a.numpy())

    def test_r_to_p(self):
        self._set_place()

        a = paddle.ones(self._shape)
        input_tensor = dist.shard_tensor(a, self._in_mesh, [dist.Replicate()])
        out = dist.reshard(
            input_tensor,
            self._out_mesh,
            [dist.Partial(dist.ReduceType.kRedSum)],
        )

        if dist.get_rank() == 0:
            np.testing.assert_equal(
                out._local_value().numpy(), input_tensor.numpy()
            )
        else:
            zeros = paddle.zeros(self._shape)
            np.testing.assert_equal(out._local_value().numpy(), zeros.numpy())

        assert np.equal(out.shape, input_tensor.shape).all()
        assert np.equal(out._local_shape, input_tensor._local_shape).all()

    def run_test_case(self):
        self.test_r_to_s()
        self.test_r_to_r()
        self.test_r_to_p()


if __name__ == '__main__':
    TestReshardRToX().run_test_case()
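The test reads its configuration from the environment variables shape, dtype, seeds, shard, and backend, and needs two ranks because the output mesh is [0, 1]. A hypothetical driver sketch follows; the launcher invocation in the final comment is an assumption about how such cases are typically run, not something taken from this commit:

import os

# Values are eval()'d or read verbatim by TestReshardRToX.__init__.
os.environ["shape"] = "(10, 20)"
os.environ["dtype"] = "float32"
os.environ["seeds"] = "2023"
os.environ["shard"] = "0"
os.environ["backend"] = "gpu"

# The case itself must then run inside a 2-rank distributed job, e.g.
# (assumed launcher): python -m paddle.distributed.launch --gpus=0,1 <test_file>.py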