Skip to content

Commit

Permalink
[PIR] Support Region Clone in Operation::Clone (#60590)
Browse files Browse the repository at this point in the history
  • Loading branch information
huangjiyi authored Jan 9, 2024
1 parent fbb5801 commit 47ecd81
Show file tree
Hide file tree
Showing 9 changed files with 209 additions and 80 deletions.
89 changes: 24 additions & 65 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -898,47 +898,6 @@ void BindInsertionPoint(pybind11::module *m) {
InsertionPoint class represents the insertion point in the Builder.)DOC");
}

Operation *BuildOpFrom(
Operation *to_copy_op,
std::unordered_map<pir::Value, pir::Value> &value_map) { // NOLINT
pir::OperationArgument to_create_argument(to_copy_op->info());
to_create_argument.attributes = to_copy_op->attributes();

VLOG(6) << "start copy op: " << to_copy_op->name();
auto origin_results = to_copy_op->results();
VLOG(6) << "start translate origin results into op type.";
std::transform(origin_results.begin(),
origin_results.end(),
std::back_inserter(to_create_argument.output_types),
[](const pir::OpResult &r) {
// OpResult -> OpType
return r.type();
});

// transform by value_map dict.
VLOG(6) << "start create op.";
auto origin_operands = to_copy_op->operands();
std::transform(origin_operands.begin(),
origin_operands.end(),
std::back_inserter(to_create_argument.inputs),
[&value_map](const pir::OpOperand &operand) {
// Operand -> OpResult
return value_map[operand.source()];
});
auto *cloned_op = Operation::Create(std::move(to_create_argument));

std::vector<int> tmp;
std::transform(origin_results.begin(),
origin_results.end(),
cloned_op->results().begin(),
std::back_inserter(tmp), // NOLINT, just a placeholder.
[&value_map](const OpResult &a, const OpResult &b) { // NOLINT
value_map[a.Value::impl()] = b.Value::impl();
return 1;
});
return cloned_op;
}

std::list<Operation *>::const_iterator list_offset(const Block *block,
int start_idx) {
auto it = block->begin();
Expand Down Expand Up @@ -1057,19 +1016,13 @@ static auto GetNoNeedBufferValue(const ::pir::Block *whole_block,
using OpResultMap =
std::pair<std::vector<pir::OpResult>, std::vector<pir::OpResult>>;
std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
const Program &program) {
Program &program) { // NOLINT
// Limitation of this function:
// 1. don't support Parameters.
// 2. don't support Regions in operator.
pir::IrContext *ctx = pir::IrContext::Instance();
auto cloned_program = std::make_shared<Program>(ctx);
std::unordered_map<pir::Value, pir::Value> value_map;
for (auto &op : *program.block()) {
auto *cloned_op = BuildOpFrom(&op, value_map);
cloned_program->block()->push_back(cloned_op);
}
pir::IrMapping mapper;
auto cloned_program = program.Clone(mapper);
std::vector<pir::OpResult> associated_array_key, associated_array_value;
for (auto &pair : value_map) {
for (auto &pair : mapper.Map<pir::Value>()) {
associated_array_key.push_back(pair.first.dyn_cast<pir::OpResult>());
associated_array_value.push_back(pair.second.dyn_cast<pir::OpResult>());
}
Expand Down Expand Up @@ -1178,21 +1131,26 @@ SplitedResult SplitForwardBackward(
std::unordered_set<pir::Value> backward_inputs;
std::tie(middle_values, backward_inputs) = AnalysisMiddleVariable(
program, forward_in_out_values, forward_range, backward_range);
std::unordered_map<pir::Value, pir::Value> forward_value_map;
std::unordered_map<pir::Value, pir::Value> backward_value_map;
pir::Builder backward_builder = pir::Builder(ctx, backward_program->block());
bool has_backward = (backward_range[1] > backward_range[0]);

// forward program construct.
VLOG(4) << "start create forward program.";
range_block_do(program.block(),
forward_range,
[&forward_value_map, &forward_program](Operation *op) {
auto *cloned_op = BuildOpFrom(op, forward_value_map);
forward_program->block()->push_back(cloned_op);
});
pir::IrMapping forward_mapper;
auto clone_options = pir::CloneOptions(true, true);
range_block_do(
program.block(),
forward_range,
[&forward_mapper, &forward_program, &clone_options](Operation *op) {
auto *cloned_op = op->Clone(forward_mapper, clone_options);
forward_program->block()->push_back(cloned_op);
});
auto &forward_value_map = forward_mapper.MutableMap<pir::Value>();

// backward program construc.
// Step1. insert data op for inputs_values and middle_values
pir::IrMapping backward_mapper;
auto &backward_value_map = backward_mapper.MutableMap<pir::Value>();
int counter = 0;
auto create_data_fn = [&backward_builder,
&backward_inputs,
Expand Down Expand Up @@ -1311,12 +1269,13 @@ SplitedResult SplitForwardBackward(

// Step2. copy backward ops .
VLOG(4) << "start copy backward ops";
range_block_do(program.block(),
backward_range,
[&backward_value_map, &backward_program](Operation *op) {
auto *cloned_op = BuildOpFrom(op, backward_value_map);
backward_program->block()->push_back(cloned_op);
});
range_block_do(
program.block(),
backward_range,
[&backward_mapper, &backward_program, &clone_options](Operation *op) {
auto *cloned_op = op->Clone(backward_mapper, clone_options);
backward_program->block()->push_back(cloned_op);
});
// counter = 0;
VLOG(4) << "start create backward outputs, inserting set_parameter ops.";
if (has_backward) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class IR_API Block {
friend class Region;
void SetParent(Region *parent);

// Take out corresponding Operation and its ownershipe.
// Take out corresponding Operation and its ownership.
friend class Operation;
Operation *Take(Operation *op);

Expand Down
61 changes: 53 additions & 8 deletions paddle/pir/core/ir_mapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,69 @@
#pragma once
#include <unordered_map>
#include "paddle/common/enforce.h"
#include "paddle/pir/core/block.h"
#include "paddle/pir/core/value.h"

namespace pir {
class Block;
class Operation;

class IrMapping {
public:
void Add(Value from, Value to) { value_map_[from] = to; }
template <typename T>
void Add(T from, T to) {
if (!from) return;
MutableMap<T>()[from] = to;
}

template <typename T>
T Lookup(T from) const {
if (!from) return static_cast<T>(nullptr);
IR_ENFORCE(Map<T>().count(from) > 0, "Not found key in IRMapping.");
return Map<T>().at(from);
}

template <typename T>
void Earse(T from) {
MutableMap<T>().erase(from);
}

Value Lookup(Value from) const {
IR_ENFORCE(value_map_.count(from) > 0, "Not Found Value in IRMapping.");
return value_map_.at(from);
void Clear() {
value_map_.clear();
block_map_.clear();
operation_map_.clear();
}
void Earse(Value from) { value_map_.erase(from); }

void Clear() { value_map_.clear(); }
template <typename T>
using MapType = std::unordered_map<T, T>;

template <typename T>
const MapType<T> &Map() const {
if constexpr (std::is_convertible<T, Value>::value)
return value_map_;
else if constexpr (std::is_convertible<T, Block *>::value)
return block_map_;
else if constexpr (std::is_convertible<T, Operation *>::value)
return operation_map_;
else
IR_THROW("Not support type in IRMapping.");
}

template <typename T>
MapType<T> &MutableMap() {
if constexpr (std::is_convertible<T, Value>::value)
return value_map_;
else if constexpr (std::is_convertible<T, Block *>::value)
return block_map_;
else if constexpr (std::is_convertible<T, Operation *>::value)
return operation_map_;
else
IR_THROW("Not support type in IRMapping.");
}

private:
std::unordered_map<Value, Value> value_map_;
MapType<Value> value_map_;
MapType<Block *> block_map_;
MapType<Operation *> operation_map_;
};

} // namespace pir
15 changes: 12 additions & 3 deletions paddle/pir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ Operation *Operation::Create(const std::vector<Value> &inputs,
}

Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
IR_ENFORCE(!options.IsCloneRegions() || num_regions_ <= 0,
"Operation CloneRegions is unimplemented currently.");
IR_ENFORCE(num_successors_ == 0,
"Operation::Clone is not unimplemented for multiple successors.");

Expand All @@ -156,10 +154,21 @@ Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
output_types.push_back(result.type());
}
auto *new_op = Create(inputs, attributes_, output_types, info_, num_regions_);
ir_mapping.Add(this, new_op);

// record outputs mapping info
for (uint32_t i = 0; i < num_results_; ++i) {
ir_mapping.Add(result(i), new_op->result(i));
ir_mapping.Add(static_cast<Value>(result(i)),
static_cast<Value>(new_op->result(i)));
}

if (options.IsCloneRegions()) {
// clone regions recursively
for (uint32_t i = 0; i < num_regions_; ++i) {
this->region(i).CloneInto(new_op->region(i), ir_mapping);
}
}

return new_op;
}

Expand Down
11 changes: 11 additions & 0 deletions paddle/pir/core/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ Program::~Program() {
}
}

std::shared_ptr<Program> Program::Clone(IrMapping& ir_mapping) {
pir::IrContext* ctx = pir::IrContext::Instance();
auto new_program = std::make_shared<Program>(ctx);
auto clone_options = CloneOptions(true, true);
for (auto& op : *block()) {
auto* new_op = op.Clone(ir_mapping, clone_options);
new_program->block()->push_back(new_op);
}
return new_program;
}

Parameter* Program::GetParameter(const std::string& name) const {
if (parameters_.count(name) != 0) {
return parameters_.at(name).get();
Expand Down
3 changes: 3 additions & 0 deletions paddle/pir/core/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "paddle/pir/core/block.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/ir_mapping.h"
#include "paddle/pir/core/operation.h"
#include "paddle/pir/core/parameter.h"

Expand Down Expand Up @@ -54,6 +55,8 @@ class IR_API Program {

static std::unique_ptr<Program> Parse(std::istream& is, IrContext* ctx);

std::shared_ptr<Program> Clone(IrMapping& ir_mapping); // NOLINT

Block* block() { return &module_.block(); }
const Block* block() const { return &module_op().block(); }

Expand Down
40 changes: 40 additions & 0 deletions paddle/pir/core/region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,46 @@ Region::Iterator Region::erase(ConstIterator position) {
return blocks_.erase(position);
}

void Region::CloneInto(Region &other, IrMapping &ir_mapping) {
if (empty()) {
return;
}
other.clear();
auto clone_options = CloneOptions(false, false);
// clone blocks, block arguments and sub operations
for (auto &block : *this) {
auto new_block = new Block;
ir_mapping.Add(&block, new_block);
for (auto &arg : block.args()) {
ir_mapping.Add(arg, new_block->AddArgument(arg.type()));
}
other.push_back(new_block);
// clone sub operations, but not map operands nor clone regions
for (auto op_iter = block.begin(); op_iter != block.end(); ++op_iter) {
new_block->push_back(op_iter->Clone(ir_mapping, clone_options));
}
}
// after all operation results are mapped, map operands and clone regions.
{
auto iter = begin();
auto new_iter = other.begin();
for (; iter != end(); ++iter, ++new_iter) {
auto op_iter = iter->begin();
auto new_op_iter = new_iter->begin();
for (; op_iter != iter->end(); ++op_iter, ++new_op_iter) {
Operation &op = *op_iter;
Operation &new_op = *new_op_iter;
// operands of new_op are same as op, now map them.
for (uint32_t i = 0; i < op.num_operands(); ++i)
new_op.operand(i).set_source(ir_mapping.Lookup(op.operand_source(i)));
// clone sub regions
for (uint32_t i = 0; i < op.num_regions(); ++i)
op.region(i).CloneInto(new_op.region(i), ir_mapping);
}
}
}
}

std::unique_ptr<pir::Block> Region::TakeBack() {
Block *block = nullptr;
if (!blocks_.empty()) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/pir/core/region.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <memory>

#include "paddle/pir/core/dll_decl.h"
#include "paddle/pir/core/ir_mapping.h"
#include "paddle/pir/core/iterator.h"
#include "paddle/pir/core/visitors.h"

Expand Down Expand Up @@ -71,6 +72,9 @@ class IR_API Region {
template <WalkOrder Order = WalkOrder::PostOrder, typename FuncT>
void Walk(FuncT &&callback);

// clone this region into another region, target region will be overwritten.
void CloneInto(Region &other, IrMapping &ir_mapping); // NOLINT

// take the last block of region.
// if region is empty, return nullptr;
std::unique_ptr<Block> TakeBack();
Expand Down
Loading

0 comments on commit 47ecd81

Please sign in to comment.