Skip to content

Commit

Permalink
add more test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Jul 22, 2020
1 parent 34d4d7b commit 2b5bcea
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions tests/cpp/auto_scheduler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ using namespace tvm::auto_scheduler;
TEST(ComputeDAG, AccessAnalyzer) {
const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3);
const auto& dag = tvm::auto_scheduler::ComputeDAG(tensors);
const auto& s0 = dag->init_state;
State s0 = dag->init_state;

int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5;
int bn_scale = 6, bn_mul = 7, bn_offset = 8, bn_add = 9, relu = 10;
Expand Down Expand Up @@ -144,8 +144,31 @@ TEST(ComputeDAG, AccessAnalyzer) {
}
}

// todo(lmzheng): Add more test cases for GetConsumer and GetProducesr after we have
// compute_inline
s0.compute_inline(bn_add);
s0.compute_inline(bn_mul);
s0.compute_inline(bias_add);
s0.compute_inline(padding);
{
std::vector<std::pair<int, int>> consumer_list = {{data, conv}, {kernel, conv}, {conv, relu}};
for (const auto& pair : consumer_list) {
dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &op_set);
CHECK_EQ(op_set.size(), 1);
CHECK_EQ((*op_set.begin()), s0->stages[pair.second]->op);
}
std::vector<std::pair<int, std::vector<int>>> producer_list = {{padding, {data}},
{conv, {padding, kernel}},
{bias_add, {conv, bias}},
{bn_mul, {bias_add, bn_scale}},
{bn_add, {bn_mul, bn_offset}},
{relu, {bn_add}}};
for (const auto& pair : producer_list) {
dag->access_analyzer.GetDirectProducers(s0->stages[pair.first]->op, &op_set);
CHECK_EQ(op_set.size(), pair.second.size());
for (const auto& target : pair.second) {
CHECK(op_set.count(s0->stages[target]->op));
}
}
}
}

int main(int argc, char** argv) {
Expand Down

0 comments on commit 2b5bcea

Please sign in to comment.