diff --git a/paddle/fluid/pir/transforms/transform_general_functions.cc b/paddle/fluid/pir/transforms/transform_general_functions.cc index 67edc63a63681..1c3c576a6f278 100644 --- a/paddle/fluid/pir/transforms/transform_general_functions.cc +++ b/paddle/fluid/pir/transforms/transform_general_functions.cc @@ -84,8 +84,7 @@ std::vector> GetUseOpsForOutput(Operation* op, auto result = op->result(index); std::vector> use_ops; for (auto it = result.use_begin(); it != result.use_end(); ++it) { - use_ops.push_back( - std::make_pair(it->owner(), it->owner()->operand_index(*it))); + use_ops.push_back(std::make_pair(it->owner(), it->index())); } return use_ops; } diff --git a/paddle/pir/core/op_operand.cc b/paddle/pir/core/op_operand.cc index 74e5dced1fc63..974420aab4dd7 100644 --- a/paddle/pir/core/op_operand.cc +++ b/paddle/pir/core/op_operand.cc @@ -53,6 +53,11 @@ Operation *OpOperand::owner() const { return impl_->owner(); } +uint32_t OpOperand::index() const { + CHECK_OPOPEREND_NULL_IMPL(index); + return impl_->index(); +} + void OpOperand::RemoveFromUdChain() { CHECK_OPOPEREND_NULL_IMPL(RemoveFromUdChain); return impl_->RemoveFromUdChain(); diff --git a/paddle/pir/core/op_operand.h b/paddle/pir/core/op_operand.h index 91636ea9ed8ba..bde1118f7b7c9 100644 --- a/paddle/pir/core/op_operand.h +++ b/paddle/pir/core/op_operand.h @@ -13,7 +13,7 @@ // limitations under the License. #pragma once - +#include #include "paddle/pir/core/dll_decl.h" namespace pir { @@ -57,6 +57,8 @@ class IR_API OpOperand { Operation *owner() const; + uint32_t index() const; + void RemoveFromUdChain(); friend Operation; diff --git a/paddle/pir/core/op_operand_impl.cc b/paddle/pir/core/op_operand_impl.cc index 44a3a5f28bb6e..bfc84e7f8beb6 100644 --- a/paddle/pir/core/op_operand_impl.cc +++ b/paddle/pir/core/op_operand_impl.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/pir/core/op_operand_impl.h" +#include "paddle/pir/core/operation.h" #include "paddle/pir/core/value_impl.h" namespace pir { @@ -40,6 +41,13 @@ OpOperandImpl::OpOperandImpl(pir::Value source, pir::Operation *owner) InsertToUdChain(); } +uint32_t OpOperandImpl::index() const { + const char *start = + reinterpret_cast(owner_) + sizeof(Operation); + const char *end = reinterpret_cast(this); + return (end - start) / sizeof(OpOperandImpl); +} + void OpOperandImpl::InsertToUdChain() { prev_use_addr_ = source_.impl()->first_use_addr(); next_use_ = source_.impl()->first_use(); diff --git a/paddle/pir/core/op_operand_impl.h b/paddle/pir/core/op_operand_impl.h index f1bc9d23c0928..585b34e7c5d91 100644 --- a/paddle/pir/core/op_operand_impl.h +++ b/paddle/pir/core/op_operand_impl.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include "paddle/pir/core/value.h" namespace pir { @@ -33,6 +34,8 @@ class OpOperandImpl { void set_source(Value value); + uint32_t index() const; + /// Remove this op_operand from the current use list. void RemoveFromUdChain(); diff --git a/paddle/pir/core/operation.cc b/paddle/pir/core/operation.cc index e09076e9a1256..fc670d4e9e44e 100644 --- a/paddle/pir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -245,17 +245,6 @@ std::vector Operation::operands_source() const { return res; } -int32_t Operation::operand_index(const OpOperand &op_operand) const { - int32_t res = -1; - for (uint32_t i = 0; i < num_operands(); ++i) { - if (op_operand == operand(i)) { - res = i; - break; - } - } - return res; -} - /// /// \brief op successor related public interfaces /// diff --git a/paddle/pir/core/operation.h b/paddle/pir/core/operation.h index 7a30816a3b485..7f3c9e28932cd 100644 --- a/paddle/pir/core/operation.h +++ b/paddle/pir/core/operation.h @@ -100,7 +100,6 @@ class IR_API alignas(8) Operation final std::vector operands(); Value operand_source(uint32_t index) const; std::vector operands_source() const; - int32_t operand_index(const OpOperand &op_operand) const; /// /// \brief op successor related public interfaces