From 3ce483b52ef4c696dccd9534ccc91998432101de Mon Sep 17 00:00:00 2001
From: winter-wang <78149749+winter-wang@users.noreply.github.com>
Date: Fri, 1 Mar 2024 10:10:24 +0800
Subject: [PATCH] [PIR] add distributed dialect. (#61978)

---
 paddle/fluid/pir/dialect/CMakeLists.txt       |   6 +
 .../distributed/ir/attribute_storage.h        | 118 ++++++++++++++++
 .../dialect/distributed/ir/dist_attribute.cc  |  73 ++++++++++
 .../dialect/distributed/ir/dist_attribute.h   | 101 ++++++++++++++
 .../dialect/distributed/ir/dist_dialect.cc    |  62 +++++++++
 .../pir/dialect/distributed/ir/dist_dialect.h |  41 ++++++
 .../pir/dialect/distributed/ir/dist_type.cc   |  43 ++++++
 .../pir/dialect/distributed/ir/dist_type.h    |  61 +++++++++
 .../pir/dialect/distributed/ir/type_storage.h |  81 +++++++++++
 paddle/fluid/pybind/pybind.cc                 |   3 +
 paddle/pir/include/core/attribute.h           |   7 +-
 paddle/pir/include/core/attribute_base.h      |  12 +-
 paddle/pir/include/core/storage_manager.h     |   2 +-
 .../include/core/storage_manager_support.h    |   8 +-
 paddle/pir/include/core/type.h                |   8 +-
 test/cpp/pir/CMakeLists.txt                   |   1 +
 test/cpp/pir/distributed/CMakeLists.txt       |   3 +
 test/cpp/pir/distributed/dist_dialect_test.cc | 127 ++++++++++++++++++
 18 files changed, 743 insertions(+), 14 deletions(-)
 create mode 100644 paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h
 create mode 100644 paddle/fluid/pir/dialect/distributed/ir/dist_attribute.cc
 create mode 100644 paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h
 create mode 100644 paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc
 create mode 100644 paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h
 create mode 100644 paddle/fluid/pir/dialect/distributed/ir/dist_type.cc
 create mode 100644 paddle/fluid/pir/dialect/distributed/ir/dist_type.h
 create mode 100644 paddle/fluid/pir/dialect/distributed/ir/type_storage.h
 create mode 100644 test/cpp/pir/distributed/CMakeLists.txt
 create mode 100644 test/cpp/pir/distributed/dist_dialect_test.cc

diff --git a/paddle/fluid/pir/dialect/CMakeLists.txt b/paddle/fluid/pir/dialect/CMakeLists.txt
index 2955a6d57afb5..d5050b49ac582 100644
--- a/paddle/fluid/pir/dialect/CMakeLists.txt
+++ b/paddle/fluid/pir/dialect/CMakeLists.txt
@@ -255,6 +255,12 @@ if(WITH_MKLDNN)
       ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_onednn_op.cc)
 endif()
 
+file(GLOB_RECURSE dist_dialect_srcs
+     "${CMAKE_CURRENT_SOURCE_DIR}/distributed/ir/*.cc")
+
+if(WITH_DISTRIBUTE)
+  set(op_dialect_srcs ${op_dialect_srcs} ${dist_dialect_srcs})
+endif()
 set(op_dialect_deps phi common pir type_info string_helper)
 
 cc_library(
diff --git a/paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h b/paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h
new file mode 100644
index 0000000000000..f572e5dae762b
--- /dev/null
+++ b/paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h
@@ -0,0 +1,118 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/common/ddim.h"
+#include "paddle/common/hash_funcs.h"
+#include "paddle/common/layout.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
+#include "paddle/phi/common/reduce_type.h"
+#include "paddle/pir/include/core/attribute_base.h"
+#include "paddle/pir/include/core/builtin_type.h"
+#include "paddle/pir/include/core/utils.h"
+#include "paddle/utils/flat_hash_map.h"
+
+namespace paddle {
+namespace dialect {
+
+struct ProcessMeshAttrStorage : public pir::AttributeStorage {
+  ///
+  /// \brief Declare ParamKey according to parameter type.
+  ///
+  using ParamKey = phi::distributed::ProcessMesh;
+
+  ProcessMeshAttrStorage(ParamKey&& process_mesh)  // NOLINT
+      : process_mesh(std::move(process_mesh)) {}
+
+  ///
+  /// \brief Each derived TypeStorage must define a Construct method, which
+  /// StorageManager uses to construct a derived TypeStorage.
+  ///
+  static ProcessMeshAttrStorage* Construct(ParamKey&& key) {
+    return new ProcessMeshAttrStorage(std::move(key));
+  }
+
+  ///
+  /// \brief Each derived TypeStorage must provide a HashValue method.
+  ///
+  static std::size_t HashValue(const ParamKey& key) { return key.hash(); }
+
+  ///
+  /// \brief Each derived TypeStorage needs to overload operator==.
+  ///
+  bool operator==(const ParamKey& key) const {
+    return process_mesh == key && process_mesh.dim_names() == key.dim_names();
+  }
+
+  ParamKey process_mesh;
+};
+
+struct TensorDistAttrStorage : public pir::AttributeStorage {
+  ///
+  /// \brief Declare ParamKey according to parameter type.
+  ///
+  using ParamKey = std::tuple<ProcessMeshAttribute,
+                              std::vector<int64_t>,
+                              flat_hash_map<int64_t, phi::ReduceType>>;
+
+  TensorDistAttrStorage(ParamKey&& param)  // NOLINT
+      : process_mesh(std::get<0>(param)),
+        dims_mapping(std::move(std::get<1>(param))),
+        partial_status(std::move(std::get<2>(param))) {}
+  ///
+  /// \brief Each derived TypeStorage must define a Construct method, which
+  /// StorageManager uses to construct a derived TypeStorage.
+  ///
+  static TensorDistAttrStorage* Construct(ParamKey&& key) {
+    return new TensorDistAttrStorage(std::move(key));
+  }
+
+  ///
+  /// \brief Each derived TypeStorage must provide a HashValue method.
+  ///
+  static std::size_t HashValue(const ParamKey& key) {
+    auto mesh_hash = std::get<0>(key).hash();
+    auto dims_map_hash = std::hash<std::vector<int64_t>>()(std::get<1>(key));
+    std::string partial_status_str = "[";
+    for (auto& itr : std::get<2>(key)) {
+      partial_status_str +=
+          "Partial(dims:" + std::to_string(itr.first) + ", " +
+          phi::ReduceTypeStrings[static_cast<int>(itr.second)] + "), ";
+    }
+    partial_status_str += "]";
+    auto combine_hash = pir::detail::hash_combine(mesh_hash, dims_map_hash);
+    return pir::detail::hash_combine(
+        combine_hash, std::hash<std::string>()(partial_status_str));
+  }
+
+  ///
+  /// \brief Each derived TypeStorage needs to overload operator==.
+  ///
+  bool operator==(const ParamKey& key) const {
+    return process_mesh == std::get<0>(key) &&
+           dims_mapping == std::get<1>(key) &&
+           partial_status == std::get<2>(key);
+  }
+
+  ProcessMeshAttribute process_mesh;
+  std::vector<int64_t> dims_mapping;
+  // The partial map holds at most mesh.size() entries, and it is iterated
+  // (copied and compared) far more often than it is randomly accessed.
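+  // e.g. a partial_status of {{1, phi::ReduceType::kRedSum}} (the value used
+  // in dist_dialect_test.cc) records that mesh dim 1 holds partial sums.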
+  flat_hash_map<int64_t, phi::ReduceType> partial_status;
+};
+
+}  // namespace dialect
+}  // namespace paddle
diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_attribute.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_attribute.cc
new file mode 100644
index 0000000000000..372d6206c2be8
--- /dev/null
+++ b/paddle/fluid/pir/dialect/distributed/ir/dist_attribute.cc
@@ -0,0 +1,73 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h"
+namespace paddle {
+namespace dialect {
+///
+/// \brief ProcessMeshAttribute interface.
+///
+const phi::distributed::ProcessMesh& ProcessMeshAttribute::process_mesh()
+    const {
+  return storage()->process_mesh;
+}
+ProcessMeshAttribute ProcessMeshAttribute::get(
+    pir::IrContext* ctx, const phi::distributed::ProcessMesh& mesh) {
+  return Base::get(ctx, mesh);
+}
+ProcessMeshAttribute ProcessMeshAttribute::get(
+    pir::IrContext* ctx,
+    const std::vector<int64_t>& shape,
+    const std::vector<int64_t>& process_ids,
+    const std::vector<std::string>& dim_names) {
+  return Base::get(ctx, shape, process_ids, dim_names);
+}
+
+///
+/// \brief TensorDistAttribute interface.
+///
+ProcessMeshAttribute TensorDistAttribute::mesh_attr() const {
+  return storage()->process_mesh;
+}
+const std::vector<int64_t>& TensorDistAttribute::dims_mapping() const {
+  return storage()->dims_mapping;
+}
+
+std::set<int64_t> TensorDistAttribute::partial_dims() const {
+  auto& partial = partial_status();
+  std::set<int64_t> keys;
+  for (auto& kv : partial) {
+    keys.emplace(kv.first);
+  }
+  return keys;
+}
+
+const flat_hash_map<int64_t, phi::ReduceType>&
+TensorDistAttribute::partial_status() const {
+  return storage()->partial_status;
+}
+
+TensorDistAttribute TensorDistAttribute::get(
+    pir::IrContext* ctx,
+    ProcessMeshAttribute mesh,
+    const std::vector<int64_t>& dims_mapping,
+    const flat_hash_map<int64_t, phi::ReduceType>& partial_status) {
+  return Base::get(ctx, mesh, dims_mapping, partial_status);
+}
+
+}  // namespace dialect
+}  // namespace paddle
+IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ProcessMeshAttribute)
+IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::TensorDistAttribute)
diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h b/paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h
new file mode 100644
index 0000000000000..1ee05404a3df9
--- /dev/null
+++ b/paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h
@@ -0,0 +1,101 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/reduce_type.h"
+#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
+#include "paddle/pir/include/core/attribute.h"
+#include "paddle/pir/include/core/builtin_attribute_storage.h"
+#include "paddle/pir/include/core/utils.h"
+#include "paddle/utils/flat_hash_map.h"
+
+namespace paddle {
+namespace dialect {
+class ProcessMeshAttrStorage;
+class TensorDistAttrStorage;
+
+class ProcessMeshAttribute : public pir::AttrBase<ProcessMeshAttribute,
+                                                  pir::Attribute,
+                                                  ProcessMeshAttrStorage> {
+ public:
+  using Base::Base;
+  const phi::distributed::ProcessMesh& process_mesh() const;
+  const std::vector<int64_t>& shape() const { return process_mesh().shape(); }
+  const std::vector<int64_t>& process_ids() const {
+    return process_mesh().process_ids();
+  }
+  const std::vector<std::string>& dim_names() const {
+    return process_mesh().dim_names();
+  }
+  int64_t size() const { return process_mesh().size(); }
+  int64_t ndim() const { return process_mesh().ndim(); }
+  int64_t dim_size(int64_t dim) const { return process_mesh().dim_size(dim); }
+  int64_t dim_size(const std::string& dim_name) const {
+    return process_mesh().dim_size(dim_name);
+  }
+  bool empty() const { return process_mesh().empty(); }
+  bool contains(int64_t process_id) const {
+    return process_mesh().contains(process_id);
+  }
+  size_t hash() const { return process_mesh().hash(); }
+
+  std::string to_string() const { return process_mesh().to_string(); }
+
+  static ProcessMeshAttribute get(pir::IrContext* ctx,
+                                  const phi::distributed::ProcessMesh& mesh);
+  static ProcessMeshAttribute get(pir::IrContext* ctx,
+                                  const std::vector<int64_t>& shape,
+                                  const std::vector<int64_t>& process_ids,
+                                  const std::vector<std::string>& dim_names);
+};
+
+class TensorDistAttribute : public pir::AttrBase<TensorDistAttribute,
+                                                 pir::Attribute,
+                                                 TensorDistAttrStorage> {
+ public:
+  using Base::Base;
+  ProcessMeshAttribute mesh_attr() const;
+  const phi::distributed::ProcessMesh& process_mesh() const {
+    return mesh_attr().process_mesh();
+  }
+  const std::vector<int64_t>& dims_mapping() const;
+
+  // Returns the mesh dims on which this tensor is partial.
+  std::set<int64_t> partial_dims() const;
+
+  const flat_hash_map<int64_t, phi::ReduceType>& partial_status() const;
+
+  static TensorDistAttribute get(
+      pir::IrContext* ctx,
+      ProcessMeshAttribute mesh,
+      const std::vector<int64_t>& dims_mapping,
+      const flat_hash_map<int64_t, phi::ReduceType>& partial_status);
+  static TensorDistAttribute get(
+      pir::IrContext* ctx,
+      const phi::distributed::ProcessMesh& mesh,
+      const std::vector<int64_t>& dims_mapping,
+      const flat_hash_map<int64_t, phi::ReduceType>& partial_status) {
+    return get(ctx,
+               ProcessMeshAttribute::get(ctx, mesh),
+               dims_mapping,
+               partial_status);
+  }
+};
+
+}  // namespace dialect
+}  // namespace paddle
+
+IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ProcessMeshAttribute)
+IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::TensorDistAttribute)
diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc
new file mode 100644
index 0000000000000..5329c0086d742
--- /dev/null
+++ b/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc
@@ -0,0 +1,62 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/type_storage.h"
+
+REGISTER_FILE_SYMBOLS(dist_dialect);
+namespace paddle {
+namespace dialect {
+
+DistDialect::DistDialect(pir::IrContext *context)
+    : pir::Dialect(name(), context, pir::TypeId::get<DistDialect>()) {
+  initialize();
+}
+
+void DistDialect::initialize() {
+  RegisterAttributes<ProcessMeshAttribute, TensorDistAttribute>();
+  RegisterTypes<DistDenseTensorType>();
+}
+
+void DistDialect::PrintType(pir::Type type, std::ostream &os) const {
+  if (auto dist_dense_tensor_type = type.dyn_cast<DistDenseTensorType>()) {
+    // Todo: Design the dist dense tensor type print format.
+    os << dist_dense_tensor_type.dense_tensor_type();
+  } else {
+    os << "error_type!";
+  }
+}
+
+void DistDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const {
+  if (auto process_mesh_attr = attr.dyn_cast<ProcessMeshAttribute>()) {
+    os << process_mesh_attr.process_mesh();
+  } else if (auto tensor_dist_attr = attr.dyn_cast<TensorDistAttribute>()) {
+    // Todo: Design the tensor dist attr print format.
+    os << tensor_dist_attr.process_mesh();
+  } else {
+    os << "error_attribute_type";
+  }
+}
+
+pir::OpPrintFn DistDialect::PrintOperation(pir::Operation *op) const {
+  return nullptr;
+}
+
+}  // namespace dialect
+}  // namespace paddle
+
+IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DistDialect)
diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h b/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h
new file mode 100644
index 0000000000000..2a7420b0a495a
--- /dev/null
+++ b/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/pir/include/core/dialect.h"
+
+namespace paddle {
+namespace dialect {
+
+class DistDialect : public pir::Dialect {
+ public:
+  explicit DistDialect(pir::IrContext* context);
+
+  static const char* name() { return "pd_dist"; }
+
+  void PrintType(pir::Type type, std::ostream& os) const override;
+
+  void PrintAttribute(pir::Attribute attr, std::ostream& os) const override;
+
+  pir::OpPrintFn PrintOperation(pir::Operation* op) const override;  // NOLINT
+
+ private:
+  void initialize();
+};
+
+}  // namespace dialect
+}  // namespace paddle
+
+IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DistDialect)
diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_type.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_type.cc
new file mode 100644
index 0000000000000..94a2d85fbcdd7
--- /dev/null
+++ b/paddle/fluid/pir/dialect/distributed/ir/dist_type.cc
@@ -0,0 +1,43 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/type_storage.h"
+
+namespace paddle {
+namespace dialect {
+
+pir::DenseTensorType DistDenseTensorType::dense_tensor_type() const {
+  return storage()->dense_tensor_type;
+}
+
+TensorDistAttribute DistDenseTensorType::tensor_dist_attr() const {
+  return storage()->tensor_dist_attr;
+}
+
+const common::DDim& DistDenseTensorType::global_ddim() const {
+  return storage()->global_ddim;
+}
+
+DistDenseTensorType DistDenseTensorType::get(
+    pir::IrContext* ctx,
+    pir::DenseTensorType dense_tensor_type,
+    TensorDistAttribute tensor_dist_attr,
+    const common::DDim& global_ddim) {
+  return Base::get(ctx, dense_tensor_type, tensor_dist_attr, global_ddim);
+}
+}  // namespace dialect
+}  // namespace paddle
+
+IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DistDenseTensorType)
diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_type.h b/paddle/fluid/pir/dialect/distributed/ir/dist_type.h
new file mode 100644
index 0000000000000..4aa08169440cc
--- /dev/null
+++ b/paddle/fluid/pir/dialect/distributed/ir/dist_type.h
@@ -0,0 +1,61 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
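+//
+// Note: DistDenseTensorType (declared below) couples a DenseTensorType, whose
+// dims() give the local shape, with a TensorDistAttribute; the tensor's
+// global shape is kept separately in global_ddim().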
+
+#pragma once
+
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
+#include "paddle/pir/include/core/builtin_type.h"
+#include "paddle/pir/include/core/type.h"
+
+namespace paddle {
+namespace dialect {
+
+class DistDenseTensorTypeStorage;
+
+class DistDenseTensorType
+    : public pir::Type::TypeBase<DistDenseTensorType,
+                                 pir::Type,
+                                 DistDenseTensorTypeStorage> {
+ public:
+  using Base::Base;
+
+  pir::DenseTensorType dense_tensor_type() const;
+  TensorDistAttribute tensor_dist_attr() const;
+  const common::DDim& global_ddim() const;
+  const common::DDim& local_ddim() const { return dense_tensor_type().dims(); }
+  Type dtype() const { return dense_tensor_type().dtype(); }
+  DataLayout data_layout() const { return dense_tensor_type().data_layout(); }
+
+  const phi::distributed::ProcessMesh& process_mesh() const {
+    return tensor_dist_attr().process_mesh();
+  }
+  const std::vector<int64_t>& dims_mapping() const {
+    return tensor_dist_attr().dims_mapping();
+  }
+  std::set<int64_t> partial_dims() const {
+    return tensor_dist_attr().partial_dims();
+  }
+  const flat_hash_map<int64_t, phi::ReduceType>& partial_status() const {
+    return tensor_dist_attr().partial_status();
+  }
+
+  static DistDenseTensorType get(pir::IrContext* ctx,
+                                 pir::DenseTensorType dense_tensor_type,
+                                 TensorDistAttribute tensor_dist_attr,
+                                 const common::DDim& global_ddim);
+};
+
+}  // namespace dialect
+}  // namespace paddle
+
+IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DistDenseTensorType)
diff --git a/paddle/fluid/pir/dialect/distributed/ir/type_storage.h b/paddle/fluid/pir/dialect/distributed/ir/type_storage.h
new file mode 100644
index 0000000000000..1f18573d3e162
--- /dev/null
+++ b/paddle/fluid/pir/dialect/distributed/ir/type_storage.h
@@ -0,0 +1,81 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <tuple>
+
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
+#include "paddle/pir/include/core/builtin_type.h"
+
+namespace paddle {
+namespace dialect {
+///
+/// \brief Define Parametric TypeStorage for DistDenseTensorType.
+///
+struct DistDenseTensorTypeStorage : public pir::TypeStorage {
+  ///
+  /// \brief Declare ParamKey according to parameter type.
+  ///
+  using ParamKey =
+      std::tuple<pir::DenseTensorType, TensorDistAttribute, common::DDim>;
+
+  DistDenseTensorTypeStorage(pir::DenseTensorType dense_tensor_type,
+                             TensorDistAttribute tensor_dist_attr,
+                             const common::DDim& global_ddim)
+      : dense_tensor_type(dense_tensor_type),
+        tensor_dist_attr(tensor_dist_attr),
+        global_ddim(global_ddim) {}
+
+  ///
+  /// \brief Each derived TypeStorage must define a Construct method, which
+  /// StorageManager uses to construct a derived TypeStorage.
+  ///
+  static DistDenseTensorTypeStorage* Construct(ParamKey&& key) {
+    return new DistDenseTensorTypeStorage(
+        std::get<0>(key), std::get<1>(key), std::get<2>(key));
+  }
+
+  ///
+  /// \brief Each derived TypeStorage must provide a HashValue method.
+  ///
+  static std::size_t HashValue(const ParamKey& key) {
+    auto dense_tensor_type_hash = std::hash<pir::Type>()(std::get<0>(key));
+    auto tensor_dist_attr_hash = std::hash<pir::Attribute>()(std::get<1>(key));
+    auto global_ddim_hash = std::hash<common::DDim>()(std::get<2>(key));
+    auto value = pir::detail::hash_combine(dense_tensor_type_hash,
+                                           tensor_dist_attr_hash);
+    return pir::detail::hash_combine(value, global_ddim_hash);
+  }
+
+  ///
+  /// \brief Each derived TypeStorage needs to overload operator==.
+  ///
+  bool operator==(const ParamKey& key) const {
+    return dense_tensor_type == std::get<0>(key) &&
+           tensor_dist_attr == std::get<1>(key) &&
+           global_ddim == std::get<2>(key);
+  }
+
+  ///
+  /// \brief DistDenseTensorTypeStorage includes three parameters:
+  /// dense_tensor_type, tensor_dist_attr and global_ddim.
+  ///
+  pir::DenseTensorType dense_tensor_type;
+  TensorDistAttribute tensor_dist_attr;
+  common::DDim global_ddim;
+};
+
+}  // namespace dialect
+}  // namespace paddle
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index f1d53f3f88750..ffaef54bb9da9 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -223,6 +223,9 @@ PYBIND11_MAKE_OPAQUE(paddle::framework::FetchType);
 
 DECLARE_FILE_SYMBOLS(init_phi);
 DECLARE_FILE_SYMBOLS(kernel_dialect);
+#ifdef PADDLE_WITH_DISTRIBUTE
+DECLARE_FILE_SYMBOLS(dist_dialect);
+#endif
 DECLARE_FILE_SYMBOLS(buffered_allocator);
 DECLARE_FILE_SYMBOLS(best_fit_allocator);
 DECLARE_FILE_SYMBOLS(aligned_allocator);
diff --git a/paddle/pir/include/core/attribute.h b/paddle/pir/include/core/attribute.h
index 9571440679b8c..2c1ca17656811 100644
--- a/paddle/pir/include/core/attribute.h
+++ b/paddle/pir/include/core/attribute.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/pir/include/core/cast_utils.h"
+#include "paddle/pir/include/core/storage_manager_support.h"
 #include "paddle/pir/include/core/type_id.h"
 
 constexpr char kAttrStopGradients[] = "stop_gradient";
@@ -87,6 +88,8 @@ class IR_API Attribute {
     return pir::dyn_cast<U>(*this);
   }
 
+  std::size_t hash() const { return std::hash<const void *>()(storage_); }
+
  protected:
   const Storage *storage_{nullptr};
 };
@@ -97,8 +100,6 @@ IR_API std::ostream &operator<<(std::ostream &os, Attribute attr);
 namespace std {
 template <>
 struct hash<pir::Attribute> {
-  std::size_t operator()(const pir::Attribute &obj) const {
-    return std::hash<const void *>()(obj);
-  }
+  std::size_t operator()(const pir::Attribute &obj) const { return obj.hash(); }
 };
 }  // namespace std
diff --git a/paddle/pir/include/core/attribute_base.h b/paddle/pir/include/core/attribute_base.h
index d6c75f2e5d8ce..0f459f23e9f99 100644
--- a/paddle/pir/include/core/attribute_base.h
+++ b/paddle/pir/include/core/attribute_base.h
@@ -16,8 +16,8 @@
 
 #include "paddle/pir/include/core/ir_context.h"
 #include "paddle/pir/include/core/storage_manager.h"
+#include "paddle/pir/include/core/storage_manager_support.h"
 #include "paddle/pir/include/core/type_id.h"
-
 namespace pir {
 class Dialect;
@@ -239,6 +239,16 @@ struct IR_API AttributeManager {
   }
 };
 
+template <typename ConcreteAttribute,
+          typename BaseAttribute,
+          typename StorageType,
+          class... TraitOrInterface>
+using AttrBase = detail::StorageHelperBase<ConcreteAttribute,
+                                           BaseAttribute,
+                                           StorageType,
+                                           AttributeManager,
+                                           TraitOrInterface...>;
+
 ///
 /// \brief Add some necessary functions to the custom Attribute class.
 ///
diff --git a/paddle/pir/include/core/storage_manager.h b/paddle/pir/include/core/storage_manager.h
index 8cacc3bd38bd0..7024e580e4a1f 100644
--- a/paddle/pir/include/core/storage_manager.h
+++ b/paddle/pir/include/core/storage_manager.h
@@ -74,7 +74,7 @@ class IR_API StorageManager {
       return static_cast<const Storage &>(*existing) == param;
     };
     auto constructor = [&]() {
-      auto *storage = Storage::Construct(param);
+      auto *storage = Storage::Construct(std::move(param));
       if (init_func) init_func(storage);
       return storage;
     };
diff --git a/paddle/pir/include/core/storage_manager_support.h b/paddle/pir/include/core/storage_manager_support.h
index 7d4d540382dcd..b729a4480ac35 100644
--- a/paddle/pir/include/core/storage_manager_support.h
+++ b/paddle/pir/include/core/storage_manager_support.h
@@ -18,8 +18,6 @@
 
 #include "paddle/pir/include/core/interface_support.h"
 #include "paddle/pir/include/core/ir_context.h"
-#include "paddle/pir/include/core/type.h"
-#include "paddle/pir/include/core/type_base.h"
 #include "paddle/pir/include/core/type_id.h"
 
 namespace pir {
@@ -68,7 +66,7 @@ class StorageHelperBase : public BaseT {
                                typename Filter>::Type;
   static ConcreteT dyn_cast_impl(BaseT type) {
-    if (type && type.abstract_type().type_id() == TypeId::get<ConcreteT>()) {
+    if (type && type.type_id() == TypeId::get<ConcreteT>()) {
       return ConcreteT(type.storage());
     }
     return ConcreteT(nullptr);
   }
@@ -107,8 +105,8 @@ class StorageHelperBase : public BaseT {
   /// \brief Get or create a new ConcreteT instance within the ctx.
   ///
   template <typename... Args>
-  static ConcreteT get(pir::IrContext *ctx, Args... args) {
-    return ManagerT::template get<ConcreteT>(ctx, args...);
+  static ConcreteT get(pir::IrContext *ctx, Args &&...args) {
+    return ManagerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
   }
 
   ///
diff --git a/paddle/pir/include/core/type.h b/paddle/pir/include/core/type.h
index 569b356135b18..fcfe0a77a8ac5 100644
--- a/paddle/pir/include/core/type.h
+++ b/paddle/pir/include/core/type.h
@@ -18,6 +18,7 @@
 
 #include "paddle/pir/include/core/cast_utils.h"
 #include "paddle/pir/include/core/storage_manager_support.h"
+#include "paddle/pir/include/core/type_base.h"
 #include "paddle/pir/include/core/type_id.h"
 
 namespace pir {
@@ -42,7 +43,6 @@ class IR_API Type {
                                                StorageType,
                                                TypeManager,
                                                TraitOrInterface...>;
-
   using Storage = TypeStorage;
   using AbstractT = AbstractType;
@@ -125,6 +125,8 @@ class IR_API Type {
   bool IsIntOrIndex() const;
   bool IsIndex() const;
 
+  std::size_t hash() const { return std::hash<const void *>()(storage_); }
+
  protected:
   const Storage *storage_{nullptr};
 
@@ -184,8 +186,6 @@ namespace std {
 ///
 template <>
 struct hash<pir::Type> {
-  std::size_t operator()(const pir::Type &obj) const {
-    return std::hash<const void *>()(obj);
-  }
+  std::size_t operator()(const pir::Type &obj) const { return obj.hash(); }
 };
 }  // namespace std
diff --git a/test/cpp/pir/CMakeLists.txt b/test/cpp/pir/CMakeLists.txt
index 420ffa8b6dc5a..e7de653656897 100644
--- a/test/cpp/pir/CMakeLists.txt
+++ b/test/cpp/pir/CMakeLists.txt
@@ -7,3 +7,4 @@ add_subdirectory(cinn)
 add_subdirectory(control_flow_dialect)
 add_subdirectory(shape_dialect)
 add_subdirectory(sub_graph)
+add_subdirectory(distributed)
diff --git a/test/cpp/pir/distributed/CMakeLists.txt b/test/cpp/pir/distributed/CMakeLists.txt
new file mode 100644
index 0000000000000..0483dbe1fdac0
--- /dev/null
+++ b/test/cpp/pir/distributed/CMakeLists.txt
@@ -0,0 +1,3 @@
+if(WITH_DISTRIBUTE)
+  paddle_test(dist_dialect_test SRCS dist_dialect_test.cc)
+endif()
diff --git a/test/cpp/pir/distributed/dist_dialect_test.cc b/test/cpp/pir/distributed/dist_dialect_test.cc
new file mode 100644
index 0000000000000..01dcb2f1010d5
--- /dev/null
+++ b/test/cpp/pir/distributed/dist_dialect_test.cc
@@ -0,0 +1,127 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include <gtest/gtest.h>
+#include <map>
+
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
+#include "paddle/pir/include/core/builtin_type.h"
+
+using namespace paddle::dialect;  // NOLINT
+
+TEST(process_mesh_test, base) {
+  pir::IrContext* ctx = pir::IrContext::Instance();
+  ctx->GetOrRegisterDialect<DistDialect>();
+  std::vector<int64_t> mesh_shape = {2, 2};
+  std::vector<int64_t> process_ids = {0, 1, 2, 3};
+  std::vector<std::string> dim_names = {"x", "y"};
+  std::vector<std::string> dim_names_2 = {"x", "s"};
+  phi::distributed::ProcessMesh process_mesh(
+      mesh_shape, process_ids, dim_names);
+
+  // construct a ProcessMeshAttribute.
+  auto mesh_attr =
+      ProcessMeshAttribute::get(ctx, mesh_shape, process_ids, dim_names);
+  auto mesh_attr_1 = ProcessMeshAttribute::get(ctx, process_mesh);
+  auto mesh_attr_2 =
+      ProcessMeshAttribute::get(ctx, mesh_shape, process_ids, dim_names_2);
+  EXPECT_EQ(mesh_attr, mesh_attr_1);
+  EXPECT_NE(mesh_attr, mesh_attr_2);
+
+  // test member function.
+  EXPECT_EQ(mesh_attr.process_mesh(), process_mesh);
+  EXPECT_EQ(mesh_attr.shape(), mesh_shape);
+  EXPECT_EQ(mesh_attr.process_ids(), process_ids);
+  EXPECT_EQ(mesh_attr.dim_names(), dim_names);
+  EXPECT_EQ(mesh_attr.size(), 4);
+  EXPECT_EQ(mesh_attr.ndim(), 2);
+  EXPECT_EQ(mesh_attr.dim_size(0), 2);
+  EXPECT_EQ(mesh_attr.dim_size("y"), 2);
+  EXPECT_FALSE(mesh_attr.empty());
+  EXPECT_TRUE(mesh_attr.contains(3));
+  EXPECT_EQ(mesh_attr.hash(), process_mesh.hash());
+  EXPECT_EQ(mesh_attr.to_string(), process_mesh.to_string());
+}
+TEST(tensor_dist_attr_test, base) {
+  pir::IrContext* ctx = pir::IrContext::Instance();
+  ctx->GetOrRegisterDialect<DistDialect>();
+
+  std::vector<int64_t> mesh_shape = {2, 3};
+  std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5};
+  std::vector<std::string> dim_names = {"x", "y"};
+  phi::distributed::ProcessMesh process_mesh(
+      mesh_shape, process_ids, dim_names);
+  std::vector<int64_t> dims_mapping = {0, -1};
+  paddle::flat_hash_map<int64_t, phi::ReduceType> partial_status,
+      partial_status_1{{1, phi::ReduceType::kRedSum}};
+
+  auto mesh_attr =
+      ProcessMeshAttribute::get(ctx, mesh_shape, process_ids, dim_names);
+
+  // construct a TensorDistAttribute.
+  auto tensor_dist_attr =
+      TensorDistAttribute::get(ctx, mesh_attr, dims_mapping, partial_status);
+  auto tensor_dist_attr_1 =
+      TensorDistAttribute::get(ctx, process_mesh, dims_mapping, partial_status);
+  auto tensor_dist_attr_2 = TensorDistAttribute::get(
+      ctx, process_mesh, dims_mapping, partial_status_1);
+  EXPECT_EQ(tensor_dist_attr, tensor_dist_attr_1);
+  EXPECT_NE(tensor_dist_attr, tensor_dist_attr_2);
+
+  // test member function.
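+  // (dims_mapping {0, -1} maps tensor dim 0 onto mesh dim 0 and leaves dim 1
+  // unsharded; -1 is the conventional "not sharded / replicated" marker.)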
+  EXPECT_EQ(tensor_dist_attr.mesh_attr(), mesh_attr);
+  EXPECT_EQ(tensor_dist_attr.process_mesh(), process_mesh);
+  EXPECT_EQ(tensor_dist_attr.dims_mapping(), dims_mapping);
+  EXPECT_EQ(tensor_dist_attr.partial_status(), partial_status);
+}
+
+TEST(dist_dense_tensor_type_test, base) {
+  pir::IrContext* ctx = pir::IrContext::Instance();
+  ctx->GetOrRegisterDialect<DistDialect>();
+  ctx->GetOrRegisterDialect<OperatorDialect>();
+  std::vector<int64_t> mesh_shape = {2, 3};
+  std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5};
+  std::vector<std::string> dim_names = {"x", "y"};
+  phi::distributed::ProcessMesh process_mesh(
+      mesh_shape, process_ids, dim_names);
+  auto mesh_attr = ProcessMeshAttribute::get(ctx, process_mesh);
+
+  std::vector<int64_t> dims_mapping = {0, -1};
+  paddle::flat_hash_map<int64_t, phi::ReduceType> partial_status{
+      {1, phi::ReduceType::kRedSum}};
+  // construct a TensorDistAttribute.
+  auto tensor_dist_attr =
+      TensorDistAttribute::get(ctx, mesh_attr, dims_mapping, partial_status);
+
+  pir::Type fp32_dtype = pir::Float32Type::get(ctx);
+  common::DDim dims = {2, 2};
+  common::DataLayout data_layout = common::DataLayout::NCHW;
+  pir::LoD lod = {{0, 1, 2}};
+  size_t offset = 0;
+  pir::DenseTensorType dense_tensor_type = pir::DenseTensorType::get(
+      ctx, fp32_dtype, dims, data_layout, lod, offset);
+
+  auto dist_densor_type =
+      DistDenseTensorType::get(ctx, dense_tensor_type, tensor_dist_attr, dims);
+
+  EXPECT_EQ(dist_densor_type.process_mesh(), process_mesh);
+  EXPECT_EQ(dist_densor_type.dims_mapping(), dims_mapping);
+  EXPECT_EQ(dist_densor_type.partial_status(), partial_status);
+  EXPECT_EQ(dist_densor_type.dtype().isa<pir::Float32Type>(), true);
+  EXPECT_EQ(dist_densor_type.global_ddim(), dims);
+  EXPECT_EQ(dist_densor_type.data_layout(), data_layout);
+  EXPECT_EQ(dist_densor_type.local_ddim(), dims);
+}