[AutoTVM] Mali autotuning fails for ResNet #3980

Closed
mbaret opened this issue Sep 20, 2019 · 23 comments

@mbaret
Contributor

mbaret commented Sep 20, 2019

Autotuning on Mali fails for ResNet18_v1 (from Gluon) with the following error:

TVMError: Not all Vars are passed in api_args:  'threadIdx.x'  does not appear in api_args
Error during compile function
-----------------------------
v0.0.4
fn (%p0: Tensor[(1, 512, 7, 7), float32], %p1: Tensor[(1, 512, 7, 7), float32], %p2: Tensor[(512, 512, 3, 3), float32], %p3: Tensor[(512, 1, 1), float32], %p4: Tensor[(512, 1, 1), float32], Primitive=1) -> Tensor[(1, 512, 7, 7), float32] {
  %0 = nn.conv2d(%p1, %p2, padding=[1, 1], channels=512, kernel_size=[3, 3]) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %1 = multiply(%0, %p3) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %2 = add(%1, %p4) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %3 = add(%p0, %2) /* ty=Tensor[(1, 512, 7, 7), float32] */;
  nn.relu(%3) /* ty=Tensor[(1, 512, 7, 7), float32] */
}

3809 and 4072 appear to describe the same issue.

I've been able to work around this by forcibly disabling the use of tophub fallback configs during task discovery in tvm/python/tvm/autotvm/task/relay_integration.py.

This doesn't seem like an elegant solution, but is there any reason why tophub needs to be used during this step? Disabling tophub here would also make it easier to add new targets which aren't yet represented in tophub.
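
For what it's worth, a rough user-side approximation of that workaround (my own untested sketch, not the in-tree patch described above): install a non-fallback dispatch context around task extraction so the tophub fallback configs are never consulted. The exact extract_from_program signature has changed across TVM versions, so the keyword arguments and op list below are assumptions.

from tvm import autotvm, relay

def extract_tasks_without_tophub(mod, params, target):
    # An ApplyHistoryBest built from an empty record list is not a
    # FallbackContext, so the build triggered during task discovery
    # should not enter the tophub context.
    with autotvm.apply_history_best([]):
        return autotvm.task.extract_from_program(
            mod["main"], params=params, target=target,
            ops=(relay.op.nn.conv2d,))  # op list is illustrative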

@tqchen
Member

tqchen commented Sep 20, 2019

@merrymercy @ZihengJiang can you followup a bit on this?

@merrymercy
Member

merrymercy commented Sep 25, 2019

I tried our official benchmark script and found that it is also broken after #3368.

The error message is

[23:40:21] /root/tvm/src/pass/loop_partition.cc:541: Cannot prove: ((((((((((((blockIdx.x*8) + threadIdx.x) % 16)/4)*4) + (threadIdx.x % 4)) + 1) - (((blockIdx.x*8) + threadIdx.x) % 16)) - 1) - 1) + 1) >= 0), when generating the post doubt loop
Traceback (most recent call last):
  File "mobile_gpu_imagenet_bench.py", line 105, in <module>
    evaluate_network(network, target, target_host, args.dtype, args.repeat)
  File "mobile_gpu_imagenet_bench.py", line 45, in evaluate_network
    shape={'data': input_shape}, params=params, dtype=dtype)
  File "/root/tvm/nnvm/python/nnvm/compiler/build_module.py", line 321, in build
    graph = graph.apply("GraphCompile")
  File "/root/tvm/nnvm/python/nnvm/graph.py", line 250, in apply
    check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
  File "/root/tvm/nnvm/python/nnvm/_base.py", line 91, in check_call
    raise NNVMError(py_str(_LIB.NNGetLastError()))
nnvm._base.NNVMError: TVMError: Not all Vars are passed in api_args:  'threadIdx.x'  does not appear in api_args

During handling of the above exception, another exception occurred:

TVMError: Not all Vars are passed in api_args:  'threadIdx.x'  does not appear in api_args
Error during compile graph
--------------------------
Graph(%input0, %input1, %input2) {
  %input0, shape=[1,512,7,7]
  %input1, shape=[4,4,128,512,4]
  %input2, shape=[1,512,7,7]
  %2 = _contrib_conv2d_winograd_without_weight_transform(%input0, %input1, out_dtype='same', kernel_layout='OIHW', tile_size='2', groups='1', padding='[1,1]', out_layout='__undef__', strides='[1,1]', layout='NCHW', kernel_size='[3,3]', channels='512', use_bias='0', dilation='[1,1]'), shape=[1,512,7,7]
  %4 = elemwise_add(%2, %input2), shape=[1,512,7,7]
  ret %4
}
graph_attr_keys = [shape, shape_num_unknown_nodes, dtype, dtype_num_unknown_nodes]

It seems to be very similar to the problem in
https://discuss.tvm.ai/t/topi-winograd-test-topi-conv2d-winograd-py-fails-for-a-given-shape/4107/5. And I guess @zhiics will look into it later.

@mbarrett97 Does your error message also look like this? How can we reproduce your error?

To reproduce my error without Mali:

git reset --hard 4273e46
git submodule init && git submodule update --recursive
rm -rf ~/.tvm
# build

Apply this diff:

diff --git a/apps/benchmark/mobile_gpu_imagenet_bench.py b/apps/benchmark/mobile_gpu_imagenet_bench.py
index c889b3d..be6051d 100644
--- a/apps/benchmark/mobile_gpu_imagenet_bench.py
+++ b/apps/benchmark/mobile_gpu_imagenet_bench.py
@@ -31,8 +31,9 @@ from util import get_network, print_progress

 def evaluate_network(network, target, target_host, dtype, repeat):
     # connect to remote device
-    tracker = tvm.rpc.connect_tracker(args.host, args.port)
-    remote = tracker.request(args.rpc_key)
+#    tracker = tvm.rpc.connect_tracker(args.host, args.port)
+#    remote = tracker.request(args.rpc_key)
+    remote = tvm

     print_progress(network)
     net, params, input_shape, output_shape = get_network(network, batch_size=1, dtype=dtype)
python3 apps/benchmark/mobile_gpu_imagenet_bench.py --rpc-key aha --network resnet-18

@zhiics
Member

zhiics commented Sep 26, 2019

I found that the shapes are not simplified here for some convs

https://github.com/dmlc/tvm/blob/d21f0ad5d946ae569ee9521578571a9d90e2d211/src/pass/storage_flatten.cc#L171-L175

I am not sure why threadIdx.x and blockIdx.x are not replaced with the corresponding bounds even though I ran the simplifier there. @kimishpatel have you seen this as well?

@eqy
Contributor

eqy commented Sep 26, 2019

(Quoting @merrymercy's comment above in full.)

I also encountered this issue in some recent runs on Mali with the tutorial script. All workloads except for the last layer in resnet18 have this problem.

@kimishpatel
Contributor

(Quoting @zhiics's comment above.)

@zhiics, can you try to Simplify the e->extent via bounded_analyzer_? (https://github.com/dmlc/tvm/blob/d21f0ad5d946ae569ee9521578571a9d90e2d211/src/pass/storage_flatten.cc#L427) I have not run into this error since I was mostly working with x86 CPU, but I think it is reasonable to assume the analyzer I added should be able to address this. If not, we should look into augmenting bounded_analyzer. I am not sure if we should be overly eager and simplify all bounds, but that might help us cover undiscovered ground.
BTW, I am going to be offline for some time (probably a month), so my responses might be delayed significantly.

@zhiics
Member

zhiics commented Sep 26, 2019

@kimishpatel Thanks for your response. I tried it but it didn't work.

@kimishpatel
Contributor

@zhiics, hmm, ir_visitor_with_analyzer right now visits only a few exprs, like For and a few others. It might be that to simplify r->extent it needs more? I am not sure right now, but that's the path I can suggest.

@zhiics
Member

zhiics commented Sep 26, 2019

Yes, that's what I think as well, as per https://discuss.tvm.ai/t/topi-winograd-test-topi-conv2d-winograd-py-fails-for-a-given-shape/4107/5

@kimishpatel
Contributor

@zhiics, oh absolutely, this makes total sense. Realize also has bounds that we should populate as you suggested on the thread, and for that matter other Stmts too.

@tqchen
Member

tqchen commented Sep 26, 2019

This error reveals two things to be improved:

  • Improve the simplifier
  • Improve the storage rewriter itself to give a clearer error message when the storage scope is index dependent.

@zhiics
Member

zhiics commented Sep 26, 2019

@tqchen @kimishpatel Yeah, I am actually trying to improve the simplifier to cover more exprs. One thing I am not quite sure about for Realize is how to bind it.

For item 2, I think we can probably log a warning when sz or comb_size is not constant, right?

https://github.com/dmlc/tvm/blob/d7998d398acaa495952e3946d21b72bcfbae6385/src/pass/storage_rewrite.cc#L579

@zhiics
Member

zhiics commented Sep 27, 2019

@tqchen I have now found where the problem starts. The issue is actually that the bound is not correctly simplified during InferBound. The old version of the simplifier can simplify the following bound to 1:

(((((((((((blockIdx.x*128) + threadIdx.x)/16)/16)/64)*16)*16) + ((((4 + (((((blockIdx.x*128) + threadIdx.x)/16) % 16)*4)) - 1)/4)*16)) + (((4 + ((((blockIdx.x*128) + threadIdx.x) % 16)*4)) - 1)/4)) + 1) - (((((((((blockIdx.x*128) + threadIdx.x)/16)/16)/64)*16)*16) + (((((((blockIdx.x*128) + threadIdx.x)/16) % 16)*4)/4)*16)) + (((((blockIdx.x*128) + threadIdx.x) % 16)*4)/4)))

but the new simplifier here

https://github.com/dmlc/tvm/blob/01e5393574c3bda90b6029e3ef76f24f839f0c9c/src/arithmetic/stmt_simplify.cc#L105-L112

can only simplify it to:

(((((((((((blockIdx.x*128) + threadIdx.x) % 256)/16)*4) + 3)/4)*16) + ((((((blockIdx.x*128) + threadIdx.x) % 16)*4) + 3)/4)) + 1) - (((blockIdx.x*128) + threadIdx.x) % 256))

This bound is then propagated down to all the other passes and eventually causes the error mentioned above.

Could you please take a look or advise why it is not simplified? Thanks.

@tqchen
Member

tqchen commented Sep 27, 2019

The main problem is likely due to the fact that we did not fully bind the range of the IterVar. This is a new requirement because we are being strict about the division mode and need to know the sign of the expression.

@zhiics
Member

zhiics commented Sep 27, 2019

Thanks.

I tried changing kv.second to Range(0, kv.second->extent). It did not seem to work, probably because the IterVar is not bound?

https://github.com/dmlc/tvm/blob/01e5393574c3bda90b6029e3ef76f24f839f0c9c/src/arithmetic/stmt_simplify.cc#L105-L112

So this will then be considered together with https://discuss.tvm.ai/t/discuss-embed-more-bound-information-into-var-or-expr/4079 ?

@tqchen
Member

tqchen commented Sep 27, 2019

Upgrading the simplification will alleviate the problem, but it would still be great for us to improve the bound context in the current setting. Both are directions for improvement.

@zhiics
Member

zhiics commented Sep 27, 2019

Thanks. I think I will probably need much more time to work on it because I need to understand the simplifier first. Will you have cycles to work on a fix? It is currently blocking our internal merge.

@tqchen
Member

tqchen commented Sep 27, 2019

It would be great if you could look into it a bit, so that we have more developers who can hack on the simplifier :) Reducing it to unit-test cases (like the case you suggested) might help the process.
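
For reference, a rough sketch of what such a unit-test reduction might look like, written against a more recent tvm.arith Python API than the 2019 code discussed here; the launch bounds (32 blocks of 128 threads) and the variable names are illustrative only.

import tvm
from tvm import tir

analyzer = tvm.arith.Analyzer()
bx = tir.Var("blockIdx.x", "int32")
tx = tir.Var("threadIdx.x", "int32")
# Bind the thread ranges; the truncdiv rewrite rules need to prove the
# operands are non-negative before they can fire.
analyzer.bind(bx, tvm.ir.Range(0, 32))
analyzer.bind(tx, tvm.ir.Range(0, 128))

idx = tir.truncmod(bx * 128 + tx, 256)
hi = (tir.truncdiv(tir.truncdiv(idx, 16) * 4 + 3, 4) * 16
      + tir.truncdiv(tir.truncmod(idx, 16) * 4 + 3, 4))
extent = hi + 1 - idx
# Mirrors the partially simplified bound quoted above; once the ranges
# are bound, this should ideally reduce all the way to 1.
print(analyzer.simplify(extent))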

@zhiics
Member

zhiics commented Sep 28, 2019

@tqchen I dug deeper into the simplifier. Should we relax the following

https://github.com/dmlc/tvm/blob/f98035b093112ce5dfdde518c86b1511830f7172/src/arithmetic/rewrite_simplify.cc#L544-L550

by removing CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)?

I think (a*x + y) / b = x * (a / b) + (y / b) holds if a is a multiple of b for any sign, isn't this true?

@yongwww
Member

yongwww commented Sep 29, 2019

@tqchen @zhiics here is my proof of the equation above.
a, x, y, and b are all integers, and a is a multiple of b, so a = b*e for some integer e.

1) If y is a multiple of b, then y = b*f for some integer f.

(a*x + y) / b
= (b*e*x + b*f)/b
= b*(e*x+f)/b
= e*x+f

x * (a / b) + (y / b) 
= x * (b*e)/b + (b*f)/b
= e*x+f

so (a*x + y) / b = x * (a / b) + (y / b) = e*x+f

2) If y is not a multiple of b, y can be written as g*b + h, where g is an integer and h is a positive integer less than |b|.

(a*x + y) / b
= (b*e*x + g*b + h)/b
= ( (e*x + g)*b + h)/b 
= ((e*x + g)*b)/b, since (e*x + g)*b is a multiple of b and h is less than b
= e*x + g


x * (a / b) + (y / b)
= x * (b*e/b) + (g*b + h)/b 
= x*e + (g*b+h)/b
= x*e + g

So, (a*x + y) / b = x * (a / b) + (y / b) 

To sum up, no matter whether y is a multiple of b or not, we have (a*x + y) / b = x * (a / b) + (y / b).

@tqchen
Member

tqchen commented Sep 29, 2019

This has to do with the division semantics, because the semantics here is truncdiv.

Take b = 2, a*x = 8, y = -3. The change of sign makes the result different because of the truncdiv. See more about the division modes in #3977.
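
A quick plain-Python check of that counterexample (my own sketch; int(p / q) emulates C-style truncation toward zero):

def truncdiv(p, q):
    # Truncating division: round the quotient toward zero, as in the TVM IR.
    return int(p / q)

a, x, y, b = 8, 1, -3, 2                         # a*x = 8, y = -3, b = 2; a is a multiple of b
lhs = truncdiv(a * x + y, b)                     # truncdiv(5, 2) = 2
rhs = x * truncdiv(a, b) + truncdiv(y, b)        # 1*4 + truncdiv(-3, 2) = 4 - 1 = 3
print(lhs, rhs)                                  # 2 3 -> the rewrite is invalid under truncdiv
print((a * x + y) // b, x * (a // b) + y // b)   # 2 2 -> it does hold under floor division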

@zhiics
Member

zhiics commented Sep 29, 2019

@mbarrett97 This problem is caused by the division mode: division in the schedule uses truncdiv in the TVM IR, which needs context information about the sign. Tianqi's pending PR #4008 migrates indexdiv to floordiv. It solves my minimal example that reproduces the problem. Can you check if it solves your problem here as well?

@mbaret
Contributor Author

mbaret commented Sep 30, 2019

It does appear to resolve my issue. However, I'm now getting a new issue when I try to run InceptionV3 on Mali:

TVMError: Check failed: pb->value != 0 (0 vs. 0) : Divide by zero
Error during compile function
-----------------------------
v0.0.4
fn (%p0: Tensor[(1, 2048, 5, 5), float32], Primitive=1) -> Tensor[(1, 2048, 1, 1), float32] {
  nn.avg_pool2d(%p0, pool_size=[8, 8], strides=[8, 8], count_include_pad=True) /* ty=Tensor[(1, 2048, 1, 1), float32] */
}

This looks like it could potentially be related to those changes (?); I'll see if I can find a minimal example to reproduce it.

@tqchen
Member

tqchen commented Nov 27, 2019

Closing due to inactivity; feel free to bring new troubleshooting threads to the discuss forum.

@tqchen tqchen closed this as completed Nov 27, 2019