Skip to content

Commit

Permalink
[IR] Support TypeAttribute. (#54984)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiweibo authored Jul 2, 2023
1 parent 89feae0 commit cc7d1f3
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 7 deletions.
3 changes: 3 additions & 0 deletions paddle/ir/core/builtin_attribute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ std::vector<Attribute> ArrayAttribute::data() const {

void* PointerAttribute::data() const { return storage()->GetAsKey(); }

Type TypeAttribute::data() const { return storage()->GetAsKey(); }

} // namespace ir

IR_DEFINE_EXPLICIT_TYPE_ID(ir::StrAttribute)
Expand All @@ -45,3 +47,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Attribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Attribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::ArrayAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::PointerAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::TypeAttribute)
10 changes: 10 additions & 0 deletions paddle/ir/core/builtin_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,15 @@ class IR_API PointerAttribute : public Attribute {
void* data() const;
};

class IR_API TypeAttribute : public Attribute {
public:
using Attribute::Attribute;

DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(TypeAttribute, TypeAttributeStorage);

Type data() const;
};

} // namespace ir

IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::StrAttribute)
Expand All @@ -113,3 +122,4 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Attribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Attribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ArrayAttribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::PointerAttribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::TypeAttribute)
22 changes: 22 additions & 0 deletions paddle/ir/core/builtin_attribute_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/utils.h"

namespace ir {
Expand Down Expand Up @@ -131,4 +132,25 @@ struct ArrayAttributeStorage : public AttributeStorage {
size_t length_ = 0;
};

struct TypeAttributeStorage : public AttributeStorage {
using ParamKey = Type;

explicit TypeAttributeStorage(const ParamKey &key) : value_(key) {}

static TypeAttributeStorage *Construct(ParamKey key) {
return new TypeAttributeStorage(key);
}

static std::size_t HashValue(const ParamKey &key) {
return std::hash<Type>()(key);
}

bool operator==(const ParamKey &key) const { return value_ == key; }

ParamKey GetAsKey() const { return value_; }

private:
Type value_;
};

} // namespace ir
3 changes: 2 additions & 1 deletion paddle/ir/core/builtin_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ void BuiltinDialect::initialize() {
PointerAttribute,
Int32Attribute,
Int64Attribute,
ArrayAttribute>();
ArrayAttribute,
TypeAttribute>();

RegisterOps<ModuleOp,
GetParameterOp,
Expand Down
2 changes: 2 additions & 0 deletions paddle/ir/core/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ void BasicIrPrinter::PrintAttribute(const Attribute& attr) {
[this](Attribute v) { this->PrintAttribute(v); },
[this]() { this->os << ","; });
os << "]";
} else if (auto type = attr.dyn_cast<TypeAttribute>()) {
os << type.data();
} else {
auto& dialect = attr.dialect();
dialect.PrintAttribute(attr, os);
Expand Down
9 changes: 5 additions & 4 deletions paddle/ir/transforms/dce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ namespace {
// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be
// removed by dce pass.
// Now just a naive implementation.
class DCEPass : public ir::Pass {
class DcePass : public ir::Pass {
public:
DCEPass() : ir::Pass("DCEPass", 0) {}
DcePass() : ir::Pass("DcePass", 0) {}

void Run(ir::Operation *op) override {
auto module_op = op->dyn_cast<ir::ModuleOp>();
IR_ENFORCE(module_op, "DCEPass should run on module op.");
IR_ENFORCE(module_op, "DcePass should run on module op.");
auto *block = module_op.block();
std::vector<ir::Operation> erased_op;
for (auto it = block->begin(); it != block->end(); ++it) {
Expand All @@ -39,6 +39,7 @@ class DCEPass : public ir::Pass {
for (uint32_t i = 0; i < (*it)->num_results(); ++i) {
use_empty &= (*it)->result(i).use_empty();
}
// TODO(wilber): Support Terminator trait.
if (use_empty && (*it)->name() != "pd.fetch") {
erased_op.push_back(**it);
}
Expand All @@ -56,6 +57,6 @@ class DCEPass : public ir::Pass {

namespace ir {

std::unique_ptr<Pass> CreateDCEPass() { return std::make_unique<DCEPass>(); }
std::unique_ptr<Pass> CreateDcePass() { return std::make_unique<DcePass>(); }

} // namespace ir
2 changes: 1 addition & 1 deletion paddle/ir/transforms/dce.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@
namespace ir {
class Pass;

IR_API std::unique_ptr<Pass> CreateDCEPass();
IR_API std::unique_ptr<Pass> CreateDcePass();

} // namespace ir
7 changes: 7 additions & 0 deletions test/cpp/ir/core/ir_attribute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"

Expand Down Expand Up @@ -63,4 +64,10 @@ TEST(attribute_test, built_in_attribute) {
string_attr_1.dyn_cast<ir::StrAttribute>();
EXPECT_EQ(string_attr_cast_1.isa<ir::StrAttribute>(), true);
EXPECT_EQ(string_attr_cast_1.size() == 8, 1);

ir::Int32Type i32_type = ir::Int32Type::get(ctx);
ir::Attribute type_attr = ir::TypeAttribute::get(ctx, i32_type);
EXPECT_TRUE(type_attr.isa<ir::TypeAttribute>());
EXPECT_EQ(type_attr.dyn_cast<ir::TypeAttribute>().data().type_id(),
i32_type.type_id());
}
2 changes: 1 addition & 1 deletion test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ TEST(pattern_rewrite, Patterns) {

ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<TestPass>());
pm.AddPass(ir::CreateDCEPass());
pm.AddPass(ir::CreateDcePass());
program.Print(std::cout);
std::cout << std::endl;
pm.Run(&program);
Expand Down

0 comments on commit cc7d1f3

Please sign in to comment.