
[Frontend][Relay][Keras] fix a wrong assertion about the kernal_layout of DepthwiseConv2D #15124

Open · wants to merge 11 commits into base: main
Conversation

@jikechao (Contributor) commented Jun 20, 2023

This PR fixes a wrong assertion. The assertion conflicts with the related code at lines 377-379 of python/tvm/relay/frontend/keras.py: for the DepthwiseConv2D operator, if input_layout = "NHWC", the kernel layout should be "HWOI" rather than "HWIO".

This PR fixes the assertion and adds a bug-triggering test case.
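For context, the layout mismatch can be illustrated in isolation (a sketch, not part of the PR; axis names follow the layout strings above):

import numpy as np

# Keras stores a DepthwiseConv2D kernel as
# (kernel_h, kernel_w, in_channels, depth_multiplier), i.e. roughly HWIO.
h, w, in_ch, mult = 2, 2, 4, 2
keras_kernel = np.zeros((h, w, in_ch, mult))   # shape (2, 2, 4, 2)

# Relay's HWOI layout for depthwise kernels puts the multiplier ("O") axis
# before the channel ("I") axis, so the conversion swaps the last two axes:
hwoi_kernel = np.transpose(keras_kernel, (0, 1, 3, 2))
print(hwoi_kernel.shape)                       # (2, 2, 2, 4)

# An assertion that insists on HWIO therefore rejects exactly the layout
# the Keras frontend emits for DepthwiseConv2D.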


Steps to reproduce

import tvm
import tvm.relay as relay
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, models

# Build a minimal Keras model with a single DepthwiseConv2D layer (NHWC data).
input_shape = (2, 32, 32, 1)
input_data = np.random.random(size=input_shape).astype("float32")
x = layers.Input(shape=input_shape[1:], dtype="float32")

kwargs = {"strides": [2, 2], "kernel_size": [2, 2]}
layer = keras.layers.DepthwiseConv2D(**kwargs)
layer.set_weights(layer.get_weights())  # no-op round-trip, kept from the original reproducer

y = layer(x)
model = models.Model(x, y)
model.summary()
res_keras = model(input_data)

# Import into Relay with the NHWC layout and compile for LLVM on CPU;
# the crash below is raised during this step.
shape_dict = {"input_1": input_shape}
mod, params = relay.frontend.from_keras(model, shape_dict, layout="NHWC")
with tvm.transform.PassContext(opt_level=3):
    executor = relay.build_module.create_executor(
        "graph", mod, tvm.cpu(0), "llvm", params
    ).evaluate()

res_tvm = executor(tvm.nd.array(input_data)).numpy()

np.testing.assert_allclose(res_keras, res_tvm, atol=1e-3, rtol=1e-3)

Crash Traceback

Traceback (most recent call last):
  36: TVMFuncCall
  35: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::vm::VMCompiler::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  34: tvm::relay::vm::VMCompiler::Lower(tvm::IRModule, tvm::runtime::Array<tvm::Target, void> const&)
  33: tvm::relay::vm::VMCompiler::LowerImpl(tvm::IRModule)
  32: tvm::relay::vm::VMCompiler::OptimizeModuleImpl(tvm::IRModule)
  31: tvm::transform::Pass::operator()(tvm::IRModule) const
  30: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  29: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  28: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  27: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  26: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::AlterOpLayout()::$_1>(tvm::relay::transform::AlterOpLayout()::$_1)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  25: tvm::relay::alter_op_layout::AlterOpLayout(tvm::RelayExpr const&)
  24: tvm::relay::ForwardRewrite(tvm::RelayExpr const&, tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)> const&, std::function<tvm::runtime::ObjectRef (tvm::relay::Call const&)>, std::function<tvm::RelayExpr (tvm::RelayExpr const&)>)
  23: tvm::relay::ForwardRewriter::Rewrite(tvm::RelayExpr const&)
  22: tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)
  21: tvm::relay::MixedModeMutator::VisitLeaf(tvm::RelayExpr const&)
  20: _ZN3tvm5relay1
  19: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  18: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  17: tvm::NodeFunctor<tvm::RelayExpr (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) const
  16: _ZZN3tvm5relay11ExprFunc
  15: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  14: tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)
  13: tvm::relay::MixedModeMutator::VisitLeaf(tvm::RelayExpr const&)
  12: _ZN3tvm5relay1
  11: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  10: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  9: tvm::NodeFunctor<tvm::RelayExpr (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) const
  8: _ZZN3tvm5relay11ExprFunc
  7: _ZN3tvm5relay1
  6: tvm::RelayExpr tvm::relay::MixedModeMutator::Rewrite<tvm::relay::CallNode>(tvm::relay::CallNode const*)
  5: tvm::relay::ForwardRewriter::Rewrite_(tvm::relay::CallNode const*, tvm::RelayExpr const&)
  4: _ZN3tvm7runtime13Pac
  3: tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  2: tvm::RelayExpr tvm::relay::LayoutRewriter<tvm::relay::alter_op_layout::AlterTransformMemorizer>(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)
  1: tvm::relay::alter_op_layout::AlterTransformMemorizerNode::CallWithNewLayouts(tvm::relay::Call const&, tvm::Attrs, std::vector<tvm::RelayExpr, std::allocator<tvm::RelayExpr> > const&)
  0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::$_2> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  File "/workplace/software/tvm/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 82, in cfun
    rv = local_pyfunc(*pyargs)
  File "/workplace/software/tvm/tvm/python/tvm/relay/op/nn/_nn.py", line 213, in alter_op_layout_conv2d
    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
  File "<decorator-gen-53>", line 2, in conv2d_alter_layout
  File "/workplace/software/tvm/tvm/python/tvm/target/generic_func.py", line 286, in dispatch_func
    return dispatch_dict[k](*args, **kwargs)
  File "/workplace/software/tvm/tvm/python/tvm/topi/x86/conv2d_alter_op.py", line 60, in _alter_conv2d_layout
    impl, outs = relay.backend.te_compiler.select_implementation(
  File "/workplace/software/tvm/tvm/python/tvm/relay/backend/te_compiler.py", line 177, in select_implementation
    all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)
  File "/workplace/software/tvm/tvm/python/tvm/relay/backend/te_compiler.py", line 118, in get_valid_implementations
    strategy = fstrategy(attrs, inputs, out_type, target)
  File "/workplace/software/tvm/tvm/python/tvm/target/generic_func.py", line 46, in __call__
    return _ffi_api.GenericFuncCallFunc(self, *args)
  File "/workplace/software/tvm/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 238, in __call__
    raise get_last_ffi_error()
  3: TVMFuncCall
  2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::$_5> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  1: tvm::GenericFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::$_2> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  File "/workplace/software/tvm/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 82, in cfun
    rv = local_pyfunc(*pyargs)
  File "/workplace/software/tvm/tvm/python/tvm/relay/op/strategy/x86.py", line 142, in conv2d_strategy_cpu
    assert kernel_layout == "HWIO"
TVMError: AssertionError
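
The failing check is the hard-coded assert kernel_layout == "HWIO" in conv2d_strategy_cpu (python/tvm/relay/op/strategy/x86.py). A minimal sketch of the direction of the fix, not the PR's actual diff (and, as the review below notes, changing the assertion alone is not enough):

# Hypothetical sketch, not the merged change: the depthwise branch of the
# NHWC case would expect HWOI, matching what the Keras frontend emits.
def check_depthwise_nhwc_kernel_layout(kernel_layout: str) -> None:
    assert kernel_layout == "HWOI", (
        f"depthwise conv2d with NHWC data expects an HWOI kernel, got {kernel_layout}"
    )

check_depthwise_nhwc_kernel_layout("HWOI")  # passes for the reproducer above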

cc @echuraev @Hzfengsy @tqchen

@tvm-bot (Collaborator) commented Jun 20, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@github-actions github-actions bot requested review from Hzfengsy, echuraev and tqchen June 20, 2023 05:46
@jikechao jikechao changed the title Fix depthwise conv2 d [Relay][x86] fix a wrong assertion about the kernal_layout of DepthwiseConv2D Jun 20, 2023
@echuraev (Contributor) left a comment

LGTM. Thanks!

@echuraev (Contributor) left a comment

Sorry, I took a more careful look at the change, and it looks like it might break some scenarios. In this test, the HWIO layout is used. Also, the description of the compute function says that the kernel layout should be in the HWIO format: https://github.com/apache/tvm/blob/main/python/tvm/topi/nn/conv2d.py#L295-L296

Could you please double-check whether your fix is correct?
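
Put differently, only the depthwise case should switch layouts; regular NHWC convolutions keep HWIO kernels. A hypothetical helper (not TVM code) makes the needed branching explicit:

# Hypothetical helper, not TVM code: the expected kernel layout depends on
# whether the NHWC conv2d call is depthwise, so a blanket HWIO -> HWOI swap
# in the assertion breaks the regular-convolution tests referenced above.
def expected_kernel_layout(data_layout: str, is_depthwise: bool) -> str:
    if data_layout == "NHWC":
        return "HWOI" if is_depthwise else "HWIO"
    return "OIHW"  # simplification: other data layouts not modeled here

assert expected_kernel_layout("NHWC", is_depthwise=False) == "HWIO"
assert expected_kernel_layout("NHWC", is_depthwise=True) == "HWOI"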

@jikechao (Contributor, Author) commented Jun 20, 2023

@echuraev Thanks for your careful review. Indeed, this PR will lead to some test failures. I will correct the PR soon.

@Aleksei-grovety (Contributor) left a comment

Could the error be in keras.py instead? In TensorFlow, depthwise_kernel_shape is defined as (H, W, I, O).
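
This is straightforward to check from Keras itself (a sketch assuming TF 2.x, where the weight is exposed as depthwise_kernel):

from tensorflow.keras import layers

# DepthwiseConv2D stores its kernel as
# (kernel_h, kernel_w, in_channels, depth_multiplier), i.e. (H, W, I, O)
# with the depth multiplier playing the role of "O".
layer = layers.DepthwiseConv2D(kernel_size=[2, 2], strides=[2, 2])
layer.build(input_shape=(2, 32, 32, 3))
print(layer.depthwise_kernel.shape)  # (2, 2, 3, 1)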

@jikechao jikechao changed the title [Relay][x86] fix a wrong assertion about the kernal_layout of DepthwiseConv2D [Frontend][Relay][Keras] fix a wrong assertion about the kernal_layout of DepthwiseConv2D Jun 22, 2023
@jikechao (Contributor, Author) commented Jul 7, 2023

@tvm-bot rerun
