From 839e1fe3baec7671d92838b8cb2e20011c066546 Mon Sep 17 00:00:00 2001 From: Mercy Date: Wed, 10 Apr 2019 18:48:12 +0800 Subject: [PATCH] [AutoTVM] fix argument type for curve feature --- src/autotvm/touch_extractor.cc | 4 ++-- tests/python/unittest/test_autotvm_feature.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index e24e757427ba..002b970588ec 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -514,10 +514,10 @@ TVM_REGISTER_API("autotvm.feature.GetItervarFeatureFlatten") TVM_REGISTER_API("autotvm.feature.GetCurveSampleFeatureFlatten") .set_body([](TVMArgs args, TVMRetValue *ret) { Stmt stmt = args[0]; - bool take_log = args[1]; + int sample_n = args[1]; std::vector ret_feature; - GetCurveSampleFeatureFlatten(stmt, take_log, &ret_feature); + GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature); TVMByteArray arr; arr.size = sizeof(float) * ret_feature.size(); diff --git a/tests/python/unittest/test_autotvm_feature.py b/tests/python/unittest/test_autotvm_feature.py index 401a8d3be407..e0736c280dc4 100644 --- a/tests/python/unittest/test_autotvm_feature.py +++ b/tests/python/unittest/test_autotvm_feature.py @@ -61,6 +61,23 @@ def test_iter_feature_gemm(): assert ans[pair[0]] == pair[1:], "%s: %s vs %s" % (pair[0], ans[pair[0]], pair[1:]) +def test_curve_feature_gemm(): + N = 128 + + k = tvm.reduce_axis((0, N), 'k') + A = tvm.placeholder((N, N), name='A') + B = tvm.placeholder((N, N), name='B') + C = tvm.compute( + A.shape, + lambda y, x: tvm.sum(A[y, k] * B[k, x], axis=k), + name='C') + + s = tvm.create_schedule(C.op) + + feas = feature.get_buffer_curve_sample_flatten(s, [A, B, C], sample_n=30) + # sample_n * #buffers * #curves * 2 numbers per curve + assert len(feas) == 30 * 3 * 4 * 2 + def test_feature_shape(): """test the dimensions of flatten feature are the same""" @@ -112,4 +129,6 @@ def get_gemm_feature(target): if __name__ == "__main__": test_iter_feature_gemm() + test_curve_feature_gemm() test_feature_shape() +