diff --git a/tests/cpp/auto_scheduler_test.cc b/tests/cpp/auto_scheduler_test.cc index 9aab47855bb40..f21fe1f5c57b3 100644 --- a/tests/cpp/auto_scheduler_test.cc +++ b/tests/cpp/auto_scheduler_test.cc @@ -68,7 +68,7 @@ using namespace tvm::auto_scheduler; TEST(ComputeDAG, AccessAnalyzer) { const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); const auto& dag = tvm::auto_scheduler::ComputeDAG(tensors); - const auto& s0 = dag->init_state; + State s0 = dag->init_state; int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5; int bn_scale = 6, bn_mul = 7, bn_offset = 8, bn_add = 9, relu = 10; @@ -144,8 +144,31 @@ TEST(ComputeDAG, AccessAnalyzer) { } } - // todo(lmzheng): Add more test cases for GetConsumer and GetProducesr after we have - // compute_inline + s0.compute_inline(bn_add); + s0.compute_inline(bn_mul); + s0.compute_inline(bias_add); + s0.compute_inline(padding); + { + std::vector> consumer_list = {{data, conv}, {kernel, conv}, {conv, relu}}; + for (const auto& pair : consumer_list) { + dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &op_set); + CHECK_EQ(op_set.size(), 1); + CHECK_EQ((*op_set.begin()), s0->stages[pair.second]->op); + } + std::vector>> producer_list = {{padding, {data}}, + {conv, {padding, kernel}}, + {bias_add, {conv, bias}}, + {bn_mul, {bias_add, bn_scale}}, + {bn_add, {bn_mul, bn_offset}}, + {relu, {bn_add}}}; + for (const auto& pair : producer_list) { + dag->access_analyzer.GetDirectProducers(s0->stages[pair.first]->op, &op_set); + CHECK_EQ(op_set.size(), pair.second.size()); + for (const auto& target : pair.second) { + CHECK(op_set.count(s0->stages[target]->op)); + } + } + } } int main(int argc, char** argv) {