From f587138e0e8014819fd5cd60eda3c68483e01aa7 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 7 Dec 2023 15:28:26 +0800 Subject: [PATCH 1/4] [PIR]Move Operation::operand_index() into Operand::index() --- .../pir/transforms/transform_general_functions.cc | 3 +-- paddle/pir/core/op_operand.cc | 5 +++++ paddle/pir/core/op_operand.h | 2 ++ paddle/pir/core/op_operand_impl.cc | 9 +++++++++ paddle/pir/core/op_operand_impl.h | 2 ++ paddle/pir/core/operation.cc | 11 ----------- paddle/pir/core/operation.h | 1 - 7 files changed, 19 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/pir/transforms/transform_general_functions.cc b/paddle/fluid/pir/transforms/transform_general_functions.cc index 67edc63a63681..1c3c576a6f278 100644 --- a/paddle/fluid/pir/transforms/transform_general_functions.cc +++ b/paddle/fluid/pir/transforms/transform_general_functions.cc @@ -84,8 +84,7 @@ std::vector> GetUseOpsForOutput(Operation* op, auto result = op->result(index); std::vector> use_ops; for (auto it = result.use_begin(); it != result.use_end(); ++it) { - use_ops.push_back( - std::make_pair(it->owner(), it->owner()->operand_index(*it))); + use_ops.push_back(std::make_pair(it->owner(), it->index())); } return use_ops; } diff --git a/paddle/pir/core/op_operand.cc b/paddle/pir/core/op_operand.cc index 74e5dced1fc63..2b22c277d7d9c 100644 --- a/paddle/pir/core/op_operand.cc +++ b/paddle/pir/core/op_operand.cc @@ -53,6 +53,11 @@ Operation *OpOperand::owner() const { return impl_->owner(); } +int32_t OpOperand::index() const { + CHECK_OPOPEREND_NULL_IMPL(index); + return impl_->index(); +} + void OpOperand::RemoveFromUdChain() { CHECK_OPOPEREND_NULL_IMPL(RemoveFromUdChain); return impl_->RemoveFromUdChain(); diff --git a/paddle/pir/core/op_operand.h b/paddle/pir/core/op_operand.h index 91636ea9ed8ba..88d5c40bbf3f3 100644 --- a/paddle/pir/core/op_operand.h +++ b/paddle/pir/core/op_operand.h @@ -57,6 +57,8 @@ class IR_API OpOperand { Operation *owner() const; + int32_t index() const; + void RemoveFromUdChain(); friend Operation; diff --git a/paddle/pir/core/op_operand_impl.cc b/paddle/pir/core/op_operand_impl.cc index 44a3a5f28bb6e..0fd99aef1e2e1 100644 --- a/paddle/pir/core/op_operand_impl.cc +++ b/paddle/pir/core/op_operand_impl.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/pir/core/op_operand_impl.h" +#include "paddle/pir/core/operation.h" #include "paddle/pir/core/value_impl.h" namespace pir { @@ -40,6 +41,14 @@ OpOperandImpl::OpOperandImpl(pir::Value source, pir::Operation *owner) InsertToUdChain(); } +int32_t OpOperandImpl::index() const { + const char *start = + reinterpret_cast(owner_) + sizeof(Operation); + const char *end = reinterpret_cast(this); + int32_t index = (start - end) / sizeof(OpOperandImpl); + return index; +} + void OpOperandImpl::InsertToUdChain() { prev_use_addr_ = source_.impl()->first_use_addr(); next_use_ = source_.impl()->first_use(); diff --git a/paddle/pir/core/op_operand_impl.h b/paddle/pir/core/op_operand_impl.h index f1bc9d23c0928..4806ddd851922 100644 --- a/paddle/pir/core/op_operand_impl.h +++ b/paddle/pir/core/op_operand_impl.h @@ -33,6 +33,8 @@ class OpOperandImpl { void set_source(Value value); + int32_t index() const; + /// Remove this op_operand from the current use list. void RemoveFromUdChain(); diff --git a/paddle/pir/core/operation.cc b/paddle/pir/core/operation.cc index e09076e9a1256..fc670d4e9e44e 100644 --- a/paddle/pir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -245,17 +245,6 @@ std::vector Operation::operands_source() const { return res; } -int32_t Operation::operand_index(const OpOperand &op_operand) const { - int32_t res = -1; - for (uint32_t i = 0; i < num_operands(); ++i) { - if (op_operand == operand(i)) { - res = i; - break; - } - } - return res; -} - /// /// \brief op successor related public interfaces /// diff --git a/paddle/pir/core/operation.h b/paddle/pir/core/operation.h index 7a30816a3b485..7f3c9e28932cd 100644 --- a/paddle/pir/core/operation.h +++ b/paddle/pir/core/operation.h @@ -100,7 +100,6 @@ class IR_API alignas(8) Operation final std::vector operands(); Value operand_source(uint32_t index) const; std::vector operands_source() const; - int32_t operand_index(const OpOperand &op_operand) const; /// /// \brief op successor related public interfaces From 79cc59f21d4e4263da444c4558996732c1b87124 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 7 Dec 2023 15:45:55 +0800 Subject: [PATCH 2/4] fix order --- paddle/pir/core/op_operand_impl.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/pir/core/op_operand_impl.cc b/paddle/pir/core/op_operand_impl.cc index 0fd99aef1e2e1..805a02d0895e4 100644 --- a/paddle/pir/core/op_operand_impl.cc +++ b/paddle/pir/core/op_operand_impl.cc @@ -45,7 +45,7 @@ int32_t OpOperandImpl::index() const { const char *start = reinterpret_cast(owner_) + sizeof(Operation); const char *end = reinterpret_cast(this); - int32_t index = (start - end) / sizeof(OpOperandImpl); + int32_t index = (end - start) / sizeof(OpOperandImpl); return index; } From 47f64fa7ce12ec717465cc7bc7b5370d34650365 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 7 Dec 2023 15:49:34 +0800 Subject: [PATCH 3/4] add IR_ENFORCE --- paddle/pir/core/op_operand_impl.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/pir/core/op_operand_impl.cc b/paddle/pir/core/op_operand_impl.cc index 805a02d0895e4..887d08a78cab4 100644 --- a/paddle/pir/core/op_operand_impl.cc +++ b/paddle/pir/core/op_operand_impl.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/pir/core/op_operand_impl.h" +#include "paddle/common/enforce.h" #include "paddle/pir/core/operation.h" #include "paddle/pir/core/value_impl.h" @@ -46,6 +47,7 @@ int32_t OpOperandImpl::index() const { reinterpret_cast(owner_) + sizeof(Operation); const char *end = reinterpret_cast(this); int32_t index = (end - start) / sizeof(OpOperandImpl); + IR_ENFORCE(index >= 0, "Required index >= 0, but received index = %d", index); return index; } From d84595ce53e3e5ea208066f990b34ac91ab195a2 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 7 Dec 2023 09:16:32 +0000 Subject: [PATCH 4/4] fix header --- paddle/pir/core/op_operand.cc | 2 +- paddle/pir/core/op_operand.h | 4 ++-- paddle/pir/core/op_operand_impl.cc | 7 ++----- paddle/pir/core/op_operand_impl.h | 3 ++- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/paddle/pir/core/op_operand.cc b/paddle/pir/core/op_operand.cc index 2b22c277d7d9c..974420aab4dd7 100644 --- a/paddle/pir/core/op_operand.cc +++ b/paddle/pir/core/op_operand.cc @@ -53,7 +53,7 @@ Operation *OpOperand::owner() const { return impl_->owner(); } -int32_t OpOperand::index() const { +uint32_t OpOperand::index() const { CHECK_OPOPEREND_NULL_IMPL(index); return impl_->index(); } diff --git a/paddle/pir/core/op_operand.h b/paddle/pir/core/op_operand.h index 88d5c40bbf3f3..bde1118f7b7c9 100644 --- a/paddle/pir/core/op_operand.h +++ b/paddle/pir/core/op_operand.h @@ -13,7 +13,7 @@ // limitations under the License. #pragma once - +#include #include "paddle/pir/core/dll_decl.h" namespace pir { @@ -57,7 +57,7 @@ class IR_API OpOperand { Operation *owner() const; - int32_t index() const; + uint32_t index() const; void RemoveFromUdChain(); diff --git a/paddle/pir/core/op_operand_impl.cc b/paddle/pir/core/op_operand_impl.cc index 887d08a78cab4..bfc84e7f8beb6 100644 --- a/paddle/pir/core/op_operand_impl.cc +++ b/paddle/pir/core/op_operand_impl.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/pir/core/op_operand_impl.h" -#include "paddle/common/enforce.h" #include "paddle/pir/core/operation.h" #include "paddle/pir/core/value_impl.h" @@ -42,13 +41,11 @@ OpOperandImpl::OpOperandImpl(pir::Value source, pir::Operation *owner) InsertToUdChain(); } -int32_t OpOperandImpl::index() const { +uint32_t OpOperandImpl::index() const { const char *start = reinterpret_cast(owner_) + sizeof(Operation); const char *end = reinterpret_cast(this); - int32_t index = (end - start) / sizeof(OpOperandImpl); - IR_ENFORCE(index >= 0, "Required index >= 0, but received index = %d", index); - return index; + return (end - start) / sizeof(OpOperandImpl); } void OpOperandImpl::InsertToUdChain() { diff --git a/paddle/pir/core/op_operand_impl.h b/paddle/pir/core/op_operand_impl.h index 4806ddd851922..585b34e7c5d91 100644 --- a/paddle/pir/core/op_operand_impl.h +++ b/paddle/pir/core/op_operand_impl.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include "paddle/pir/core/value.h" namespace pir { @@ -33,7 +34,7 @@ class OpOperandImpl { void set_source(Value value); - int32_t index() const; + uint32_t index() const; /// Remove this op_operand from the current use list. void RemoveFromUdChain();