Skip to content

Commit

Permalink
[IR] add region data structure. (#54185)
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang authored May 30, 2023
1 parent 9efa5af commit 88e4362
Show file tree
Hide file tree
Showing 17 changed files with 308 additions and 92 deletions.
10 changes: 5 additions & 5 deletions paddle/fluid/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ inline ir::Operation* InsertSliceOperationForTarget(
defining_info.value.type().dyn_cast<ir::VectorType>();
ir::Operation* operation =
ir::Operation::create({defining_info.value},
{src_vec_type[defining_info.idx_in_vector]},
op_attribute_map,
{src_vec_type[defining_info.idx_in_vector]},
op_info);
program->InsertOp(operation);
ir::OpResult target_op_result = operation->GetResultByIndex(0);
Expand All @@ -136,7 +136,7 @@ inline ir::Operation* InsertCombineOperationForTarget(
}
ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec);
ir::Operation* operation =
ir::Operation::create(src_values, {target_vec_type}, {}, op_info);
ir::Operation::create(src_values, {}, {target_vec_type}, op_info);
program->InsertOp(operation);
return operation;
}
Expand Down Expand Up @@ -281,7 +281,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc);
auto op_info = LoopkUpOpInfo(ctx, op_desc);
ir::Operation* operation =
ir::Operation::create(op_inputs, op_output_types, {}, op_info);
ir::Operation::create(op_inputs, {}, op_output_types, op_info);
program->InsertOp(operation);
RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);

Expand All @@ -299,7 +299,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx,
std::tie(op_output_types, arg_to_idx) = GenerateOperationOutput(ctx, op_desc);
auto op_info = LoopkUpOpInfo(ctx, op_desc);
ir::Operation* operation =
ir::Operation::create(op_inputs, op_output_types, {}, op_info);
ir::Operation::create(op_inputs, {}, op_output_types, op_info);
program->InsertOp(operation);
RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);

Expand All @@ -315,7 +315,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
OpOutputTypeList op_output_types = {};
auto op_info = LoopkUpOpInfo(ctx, op_desc);
ir::Operation* operation =
ir::Operation::create(op_inputs, op_output_types, {}, op_info);
ir::Operation::create(op_inputs, {}, op_output_types, op_info);
program->InsertOp(operation);

return operation;
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock(
};
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
ir::Operation* operation = ir::Operation::create(
{}, {translated_var_type}, op_attribute_map, op_info);
{}, op_attribute_map, {translated_var_type}, op_info);
program->InsertOp(operation);
param_map[var->Name()] =
VariableDefiningInfo(operation->GetResultByIndex(0));
Expand Down
14 changes: 14 additions & 0 deletions paddle/ir/core/block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@

namespace ir {
Block::~Block() { clear(); }
void Block::push_back(Operation *op) {
op->set_parent(this);
ops_.push_back(op);
}

void Block::push_front(Operation *op) {
op->set_parent(this);
ops_.push_front(op);
}

Block::iterator Block::insert(const_iterator iterator, Operation *op) {
op->set_parent(this);
return ops_.insert(iterator, op);
}

void Block::clear() {
while (!empty()) {
Expand Down
24 changes: 15 additions & 9 deletions paddle/ir/core/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,23 @@

#pragma once

#include <cstddef>
#include <list>
#include "paddle/ir/core/operation.h"

namespace ir {
class Region;

class Block {
public:
using iterator = std::list<Operation *>::iterator;
using reverse_iterator = std::list<Operation *>::reverse_iterator;
using const_iterator = std::list<Operation *>::const_iterator;

Block() = default;
~Block();

Region *parent() const { return parent_; }
bool empty() const { return ops_.empty(); }
size_t size() const { return ops_.size(); }

Expand All @@ -34,21 +39,22 @@ class Block {
reverse_iterator rbegin() { return ops_.rbegin(); }
reverse_iterator rend() { return ops_.rend(); }

Operation *back() { return ops_.back(); }
Operation *front() { return ops_.front(); }
void push_back(Operation *op) { ops_.push_back(op); }
void push_front(Operation *op) { ops_.push_front(op); }
std::list<Operation *>::iterator insert(
std::list<Operation *>::const_iterator iterator, Operation *op) {
return ops_.insert(iterator, op);
}
Operation *back() const { return ops_.back(); }
Operation *front() const { return ops_.front(); }
void push_back(Operation *op);
void push_front(Operation *op);
iterator insert(const_iterator iterator, Operation *op);
void clear();

private:
Block(Block &) = delete;
void operator=(Block &) = delete;
Block &operator=(const Block &) = delete;

friend class Region;
void set_parent(Region *parent) { parent_ = parent; }

private:
Region *parent_; // not owned
std::list<Operation *> ops_; // owned
};
} // namespace ir
10 changes: 5 additions & 5 deletions paddle/ir/core/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/region.h"

namespace ir {
Operation *Builder::insert(Operation *op) {
Expand All @@ -25,17 +26,16 @@ Operation *Builder::insert(Operation *op) {
}

/// Create an operation given the fields represented as an OperationState.
Operation *Builder::create(const OperationArgument &argument) {
return insert(Operation::create(argument));
Operation *Builder::create(OperationArgument &&argument) {
return insert(Operation::create(std::move(argument)));
}

/// Creates an operation with the given fields.
Operation *Builder::create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
const AttributeMap &attribute,
const std::vector<ir::Type> &output_types,
ir::OpInfo op_info) {
OperationArgument argument(op_info, inputs, output_types, attribute);
return create(argument);
return create(OperationArgument(inputs, attribute, output_types, op_info));
}

} // namespace ir
6 changes: 3 additions & 3 deletions paddle/ir/core/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,20 @@ class Builder {
Operation *insert(Operation *op);

/// Creates an operation given the fields represented as an OperationState.
Operation *create(const OperationArgument &argument);
Operation *create(OperationArgument &&argument);

/// Creates an operation with the given fields.
Operation *create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
const AttributeMap &attribute,
const std::vector<ir::Type> &output_types,
ir::OpInfo op_info);

/// Create an operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpTy create(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::build(*this, argument, std::forward<Args>(args)...);
Operation *op = create(argument);
Operation *op = create(std::move(argument));
return op->dyn_cast<OpTy>();
}

Expand Down
1 change: 1 addition & 0 deletions paddle/ir/core/builtin_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/ir/core/op_base.h"

namespace ir {

///
/// \brief GetParameterOp: OpResult = GetParameterOp({StrAttribute,
/// StrAttribute})
Expand Down
68 changes: 49 additions & 19 deletions paddle/ir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,31 @@
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/utils.h"

namespace ir {
Operation *Operation::create(const OperationArgument &argument) {
return create(argument.inputs_,
argument.output_types_,
argument.attribute_,
argument.info_);
Operation *Operation::create(OperationArgument &&argument) {
Operation *op = create(argument.inputs,
argument.attribute,
argument.output_types,
argument.info,
argument.regions.size());

for (size_t index = 0; index < argument.regions.size(); ++index) {
op->GetRegion(index).TakeBody(std::move(*argument.regions[index]));
}
return op;
}

// Allocate the required memory based on the size and number of inputs, outputs,
// and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, Operand.
Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
const AttributeMap &attribute,
ir::OpInfo op_info) {
const std::vector<ir::Type> &output_types,
ir::OpInfo op_info,
size_t num_regions) {
// 0. Verify
if (op_info) {
op_info.verify(inputs, output_types, attribute);
Expand All @@ -50,7 +58,9 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
: sizeof(detail::OpInlineResultImpl) * num_results;
size_t operand_mem_size = sizeof(detail::OpOperandImpl) * num_operands;
size_t op_mem_size = sizeof(Operation);
size_t base_size = result_mem_size + op_mem_size + operand_mem_size;
size_t region_mem_size = num_regions * sizeof(Region);
size_t base_size =
result_mem_size + op_mem_size + operand_mem_size + region_mem_size;
// 2. Malloc memory.
char *base_ptr = reinterpret_cast<char *>(aligned_malloc(base_size, 8));
// 3.1. Construct OpResults.
Expand All @@ -65,8 +75,8 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
}
}
// 3.2. Construct Operation.
Operation *op =
new (base_ptr) Operation(num_results, num_operands, attribute, op_info);
Operation *op = new (base_ptr)
Operation(attribute, op_info, num_results, num_operands, num_regions);
base_ptr += sizeof(Operation);
// 3.3. Construct OpOperands.
if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
Expand All @@ -76,13 +86,27 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op);
base_ptr += sizeof(detail::OpOperandImpl);
}

// 3.4. Construct Regions
if (num_regions > 0) {
op->regions_ = reinterpret_cast<Region *>(base_ptr);
for (size_t idx = 0; idx < num_regions; idx++) {
new (base_ptr) Region(op);
base_ptr += sizeof(Region);
}
}
return op;
}

// Call destructors for OpResults, Operation, and OpOperands in sequence, and
// finally free memory.
void Operation::destroy() {
// Deconstruct Regions.
if (num_regions_ > 0) {
for (size_t idx = 0; idx < num_regions_; idx++) {
regions_[idx].~Region();
}
}

// 1. Get aligned_ptr by result_num.
uint32_t max_inline_result_num =
detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
Expand Down Expand Up @@ -136,15 +160,16 @@ void Operation::destroy() {

IrContext *Operation::ir_context() const { return op_info_.ir_context(); }

Operation::Operation(uint32_t num_results,
Operation::Operation(const AttributeMap &attribute,
ir::OpInfo op_info,
uint32_t num_results,
uint32_t num_operands,
const AttributeMap &attribute,
ir::OpInfo op_info) {
num_results_ = num_results;
num_operands_ = num_operands;
attribute_ = attribute;
op_info_ = op_info;
}
uint32_t num_regions)
: attribute_(attribute),
op_info_(op_info),
num_results_(num_results),
num_operands_(num_operands),
num_regions_(num_regions) {}

ir::OpResult Operation::GetResultByIndex(uint32_t index) const {
if (index >= num_results_) {
Expand Down Expand Up @@ -198,4 +223,9 @@ std::string Operation::print() {

std::string Operation::op_name() const { return op_info_.name(); }

Region &Operation::GetRegion(unsigned index) {
assert(index < num_regions_ && "invalid region index");
return regions_[index];
}

} // namespace ir
Loading

0 comments on commit 88e4362

Please sign in to comment.