Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#63 from Superjomn/refactor/as
Browse files Browse the repository at this point in the history
refactor Object.As to as, as a reminder that it lacks type checking
  • Loading branch information
Superjomn authored Mar 4, 2020
2 parents 5a59e6e + 0772634 commit 3f0371a
Show file tree
Hide file tree
Showing 14 changed files with 42 additions and 41 deletions.
4 changes: 2 additions & 2 deletions cinn/common/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ struct Object {

//! Cast to a derived type.
template <typename T>
T* As() {
T* as() {
return static_cast<T*>(this);
}

//! Cast to a derived type.
template <typename T>
const T* As() const {
const T* as() const {
return static_cast<const T*>(this);
}

Expand Down
4 changes: 2 additions & 2 deletions cinn/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ std::vector<Expr *> IfThenElse::expr_fields() { return {&condition, &true_case,
std::vector<const Expr *> IfThenElse::expr_fields() const { return {&condition, &true_case, &false_case}; }

Expr Store::Make(Expr buffer, Expr value, Expr index) {
CHECK(buffer->As<_Buffer_>()) << "buffer should be _Buffer_ type";
CHECK(buffer.As<_Buffer_>()) << "buffer should be _Buffer_ type";
auto node = make_shared<Store>();
node->buffer = buffer;
node->value = value;
Expand Down Expand Up @@ -320,7 +320,7 @@ Var &Var::operator=(const _Var_ *x) {
}

Load::Load(Expr buffer, Expr index) : ExprNode<Load>(buffer->type().ElementOf()), buffer(buffer), index(index) {
CHECK(buffer->As<_Buffer_>());
CHECK(buffer.As<_Buffer_>());
CHECK(index->type() == Int(32));
}

Expand Down
2 changes: 1 addition & 1 deletion cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ class Range : public IrNodeRef {
Range() = default;
Range(IrNodeRef n) : IrNodeRef(n) {}
Range(_Range_* n);
_Range_* operator->() const { return get()->As<_Range_>(); }
_Range_* operator->() const { return get()->as<_Range_>(); }
};

class _Range_ : public ExprNode<_Range_> {
Expand Down
4 changes: 2 additions & 2 deletions cinn/ir/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,14 @@ void IrPrinter::Visit(const Select *x) {
os_ << ")";
}
void IrPrinter::Visit(const Load *x) {
auto *node = x->buffer->As<ir::_Buffer_>();
auto *node = x->buffer.As<ir::_Buffer_>();
CHECK(node);
os_ << node->tensor_addr << "[";
Print(x->index);
os_ << "]";
}
void IrPrinter::Visit(const Store *x) {
auto *buffer_node = x->buffer->As<ir::_Buffer_>();
auto *buffer_node = x->buffer.As<ir::_Buffer_>();
CHECK(buffer_node->node_type() == ir::_Buffer_::_node_type_);
CHECK(buffer_node);
CHECK(buffer_node->tensor_addr.defined());
Expand Down
11 changes: 6 additions & 5 deletions cinn/ir/lowered_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void _LoweredFunc_::PrepareBufferCastExprs() {
auto buffers = CollectAllBufferReference();
VLOG(3) << "Function used " << buffers.size() << " buffers";
for (auto& b : buffers) {
auto* node = b->As<ir::_Buffer_>();
auto* node = b.As<ir::_Buffer_>();
CHECK(node);
std::string buffer_name = b->name;
std::string tensor_name = BufferGetTensorName(node);
Expand All @@ -81,15 +81,16 @@ void _LoweredFunc_::PrepareBufferCastExprs() {
}
}

std::vector<Buffer> _LoweredFunc_::CollectAllBufferReference() {
auto buffer_exprs = ir::CollectIRNodes(body, [](const Expr* expr) { return expr->As<ir::_Buffer_>(); });
std::vector<Buffer> _LoweredFunc_::CollectAllBufferReference() const {
std::set<Expr> buffer_exprs = ir::CollectIRNodes(body, [](const Expr* expr) { return expr->As<ir::_Buffer_>(); });

std::vector<Buffer> buffers;
// remove the duplicate buffer by their name.
std::set<std::string> names;

for (auto& expr : buffer_exprs) {
Buffer b(expr->As<_Buffer_>());
for (const Expr& expr : buffer_exprs) {
Expr& _expr = *const_cast<Expr*>(&expr);
Buffer b(_expr.As<_Buffer_>());
if (names.count(b->name)) continue;
buffers.push_back(b);
names.insert(b->name);
Expand Down
2 changes: 1 addition & 1 deletion cinn/ir/lowered_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ struct _LoweredFunc_ : ExprNode<_LoweredFunc_> {
void PrepareBufferCastExprs();
//! Get all the Buffers the function body references.
//! NOTE it will return the buffers with duplicates removed(by comparing their name).
std::vector<Buffer> CollectAllBufferReference();
std::vector<Buffer> CollectAllBufferReference() const;
};

} // namespace ir
Expand Down
4 changes: 2 additions & 2 deletions cinn/lang/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class Buffer {
public:
explicit Buffer(Type type, const std::string& name = "");

ir::_Buffer_* operator->() { return buffer_->As<ir::_Buffer_>(); }
const ir::_Buffer_* operator->() const { return buffer_->As<ir::_Buffer_>(); }
ir::_Buffer_* operator->() { return buffer_.As<ir::_Buffer_>(); }
const ir::_Buffer_* operator->() const { return buffer_.As<ir::_Buffer_>(); }

ir::Buffer buffer() const { return buffer_; }

Expand Down
4 changes: 2 additions & 2 deletions cinn/lang/compute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ TEST(Compute, basic) {

ir::Tensor y = Compute(
{100, 100}, [=](Var i, Var j) -> Expr { return x(i, j) + 1.f; }, "y");
LOG(INFO) << "compute: " << y->operaion->As<ir::ComputeOp>()->body[0];
LOG(INFO) << "compute: " << y->operaion->as<ir::ComputeOp>()->body[0];

ir::Tensor z = Compute(
{100, 100}, [=](Var i, Var j) -> Expr { return y(i, j) * 2.f; }, "z");

lang::Buffer z_buffer(Float(32));
z->Bind(z_buffer);

LOG(INFO) << "z: " << z->operaion->As<ir::ComputeOp>()->body[0];
LOG(INFO) << "z: " << z->operaion->as<ir::ComputeOp>()->body[0];

auto schedule = poly::CreateSchedule(z);
LOG(INFO) << "group: " << schedule->gened_groups().size();
Expand Down
6 changes: 3 additions & 3 deletions cinn/lang/lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ struct WriteTeller : public ir::IRMutator<const Expr*> {
void Visit(const Expr* expr, const Expr* op) override { IRMutator::Visit(expr, op); }

void Visit(const ir::Load* expr, const Expr* op) override {
auto* node = expr->As<ir::Load>();
auto* node = op->As<ir::Load>();
CHECK(node);
auto* buffer = node->buffer->As<ir::_Buffer_>();
auto* buffer = node->buffer.As<ir::_Buffer_>();
CHECK(buffer);
buffer_written.insert(buffer->name);
IRMutator::Visit(expr, op);
Expand All @@ -76,7 +76,7 @@ std::vector<ir::Argument> PrepareArguments(const std::vector<Tensor>& tensors, c

for (auto& tensor : tensors) {
bool is_input = teller.buffer_written.count(tensor->name);
auto* tensor_node = tensor->As<ir::_Tensor_>();
auto* tensor_node = tensor.As<ir::_Tensor_>();
args.emplace_back(ir::TensorGetBufferName(tensor_node),
ir::Argument::Kind::kBuffer,
tensor->type().ElementOf(),
Expand Down
4 changes: 2 additions & 2 deletions cinn/lang/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
namespace cinn {
namespace lang {

_Module_ *Module::self() { return module_->As<_Module_>(); }
const _Module_ *Module::self() const { return module_->As<_Module_>(); }
_Module_ *Module::self() { return module_->as<_Module_>(); }
const _Module_ *Module::self() const { return module_->as<_Module_>(); }

Module::Module(const std::string &name, const Target &target) : module_(make_shared<_Module_>()) {
self()->name = name;
Expand Down
28 changes: 14 additions & 14 deletions cinn/lang/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Tensor _Tensor_::Make(const std::string &name,
CHECK(!name.empty()) << "Tensor name is set empty";

auto op = ComputeOp::Make(name, tag, attrs, axis, body, shape);
auto *compute_op = op->As<ComputeOp>();
auto *compute_op = op->as<ComputeOp>();

CHECK_EQ(axis.size(), shape.size()) << "axis not match the dimension in shape";
compute_op->axis = axis;
Expand Down Expand Up @@ -95,7 +95,7 @@ Expr Tensor::operator()(const std::vector<Expr> &indices) const {

const char *_Tensor_::operation_type() const {
if (!operaion.defined()) return "";
return operaion->As<ir::_Operation_>()->func_type();
return operaion->as<ir::_Operation_>()->func_type();
}

bool _Tensor_::is_compute_node() const { return std::strcmp(operation_type(), ir::ComputeOp::__func_type__) == 0; }
Expand All @@ -105,12 +105,12 @@ bool _Tensor_::is_placeholder_node() const {

ComputeOp *_Tensor_::get_compute_op() const {
if (!is_compute_node()) return nullptr;
return operaion->As<ComputeOp>();
return operaion->as<ComputeOp>();
}

PlaceholderOp *_Tensor_::get_placeholder_op() const {
if (!is_placeholder_node()) return nullptr;
return operaion->As<PlaceholderOp>();
return operaion->as<PlaceholderOp>();
}

void _Tensor_::InitStage() {
Expand All @@ -127,9 +127,9 @@ void _Tensor_::InitStage() {
if (stage_shared) return;
stage_shared = new Shared<poly::Stage>;
auto &shared_stage = *static_cast<Shared<poly::Stage> *>(stage_shared);
auto *op = operaion->As<_Operation_>();
auto *op = operaion->as<_Operation_>();
if (is_compute_node()) {
auto &body = op->As<ComputeOp>()->body;
auto &body = op->as<ComputeOp>()->body;
CHECK_EQ(body.size(), 1UL) << "only support functional programming";
shared_stage = make_shared<poly::Stage>(GenerateIslDomain(), body.front());
} else {
Expand All @@ -146,7 +146,7 @@ void _Tensor_::DropStage() {

poly::Stage *_Tensor_::stage() {
if (!stage_shared) return nullptr;
return (*static_cast<Shared<poly::Stage> *>(stage_shared))->As<poly::Stage>();
return (*static_cast<Shared<poly::Stage> *>(stage_shared))->as<poly::Stage>();
}

void _Tensor_::InitAxis() {
Expand All @@ -167,14 +167,14 @@ isl::set _Tensor_::GenerateIslDomain() {
}
std::vector<Expr *> _Tensor_::expr_fields() {
std::vector<Expr *> res;
const char *func_type = operaion->As<ir::_Operation_>()->func_type();
const char *func_type = operaion->as<ir::_Operation_>()->func_type();
if (operaion.defined()) {
if (func_type == ir::ComputeOp::__func_type__) {
auto *op = operaion->As<ir::ComputeOp>();
auto *op = operaion->as<ir::ComputeOp>();
for (auto &expr : op->body) res.push_back(&expr);
for (auto &expr : op->shape) res.push_back(&expr);
} else if (func_type == ir::PlaceholderOp::__func_type__) {
auto *op = operaion->As<ir::PlaceholderOp>();
auto *op = operaion->as<ir::PlaceholderOp>();
for (auto &expr : op->shape) res.push_back(&expr);
} else {
NOT_IMPLEMENTED
Expand All @@ -185,14 +185,14 @@ std::vector<Expr *> _Tensor_::expr_fields() {

std::vector<const Expr *> _Tensor_::expr_fields() const {
std::vector<const Expr *> res;
const char *func_type = operaion->As<ir::_Operation_>()->func_type();
const char *func_type = operaion->as<ir::_Operation_>()->func_type();
if (operaion.defined()) {
if (is_compute_node()) {
auto *op = operaion->As<ir::ComputeOp>();
auto *op = operaion->as<ir::ComputeOp>();
for (auto &expr : op->body) res.push_back(&expr);
for (auto &expr : op->shape) res.push_back(&expr);
} else if (is_placeholder_node()) {
auto *op = operaion->As<ir::PlaceholderOp>();
auto *op = operaion->as<ir::PlaceholderOp>();
for (auto &expr : op->shape) res.push_back(&expr);
} else {
LOG(ERROR) << "func_type: " << func_type;
Expand All @@ -213,7 +213,7 @@ _Operation_ *Operation::operator->() { return static_cast<_Operation_ *>(get());

Expr _Tensor_::body() const {
if (is_placeholder_node()) return Expr();
if (is_compute_node()) return operaion->As<ir::ComputeOp>()->body.front();
if (is_compute_node()) return operaion->as<ir::ComputeOp>()->body.front();
NOT_IMPLEMENTED;
}

Expand Down
2 changes: 1 addition & 1 deletion cinn/optim/replace_call_with_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct ReplaceCallWithExprModifier : public ir::IRMutator<> {

private:
void Visit(const ir::Call *expr, Expr *op) override {
auto *node = expr->As<ir::Call>();
auto *node = op->As<ir::Call>();
CHECK(!node->name.empty()) << "Call has no name";
VLOG(2) << "Processing Call node " << node->name;
if (statement_ != node->name) return;
Expand Down
6 changes: 3 additions & 3 deletions cinn/poly/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ PolyScheduler &PolyScheduler::After(const Stage &a, const Stage &b, int level) {

common::GraphEdge *a_edge, *b_edge;
std::tie(a_edge, b_edge) = a_node->LinkTo<ScheduleGraphEdge>(b_node);
a_edge->As<ScheduleGraphEdge>()->level = level;
b_edge->As<ScheduleGraphEdge>()->level = level;
a_edge->as<ScheduleGraphEdge>()->level = level;
b_edge->as<ScheduleGraphEdge>()->level = level;
return *this;
}

Expand All @@ -153,7 +153,7 @@ std::map<std::string, isl::map> PolyScheduler::BuildSchedule() const {
ScheduleGraph::edge_order_t edge_order;
std::tie(node_order, edge_order) = schedule_graph_.topological_order();
for (auto *edge : edge_order) {
auto *schedule_edge = edge->As<ScheduleGraphEdge>();
auto *schedule_edge = edge->as<ScheduleGraphEdge>();
auto *a_node = schedule_graph_.RetriveNode(edge->source()->As<ScheduleGraphNode>()->time_schedule.id())
->As<ScheduleGraphNode>();
auto *b_node =
Expand Down
2 changes: 1 addition & 1 deletion cinn/poly/stage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ std::vector<std::string> Stage::input_statements() const {
auto call_exprs = ir::CollectIRNodes(expr_, [](const Expr *x) { return x->As<ir::Call>(); });
std::set<std::string> statements;
for (auto &expr : call_exprs) {
auto call_name = expr->As<ir::Call>()->name;
auto call_name = expr.As<ir::Call>()->name;
if (call_name != id()) statements.insert(call_name);
}
return std::vector<std::string>(statements.begin(), statements.end());
Expand Down

0 comments on commit 3f0371a

Please sign in to comment.