This repository has been archived by the owner on Feb 1, 2020. It is now read-only.

[REVIEW] Tensorflow frontend support #472

Closed
wants to merge 7 commits

Conversation

srkreddy1238
Member

Inception V1 supported now.

@srkreddy1238 srkreddy1238 force-pushed the master branch 6 times, most recently from 716ff6a to d7b21bc Compare May 9, 2018 18:15
@srkreddy1238
Member Author

@tqchen review pls

@tqchen
Member

tqchen commented May 13, 2018

@Huyuwei can you help do a pass of review?

@@ -5,3 +5,4 @@
from .coreml import from_coreml
from .keras import from_keras
from .darknet import from_darknet
from .tensorflow import from_tensorflow
Member

The frontend should always be optional; unless from_tensorflow is called, we should not depend on the availability of tensorflow.

Member

Please refer to the implementation of from_mxnet etc. to see how they import the dependency within the function.
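For reference, a minimal sketch of that deferred-import pattern (the body is illustrative, reusing this PR's GraphProto):

def from_tensorflow(graph):
    """Convert a TensorFlow GraphDef into an NNVM symbol and params."""
    # Deferred import: the frontend module must remain importable even
    # when tensorflow is absent; only calling from_tensorflow needs it.
    try:
        import tensorflow  # pylint: disable=unused-import
    except ImportError:
        raise ImportError("tensorflow is required to use from_tensorflow")
    g = GraphProto()  # the converter class defined in this file
    return g.from_tensorflow(graph)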

Member Author

Done

@@ -0,0 +1,512 @@
# pylint: disable=import-self, invalid-name, unused-argument, too-many-nested-blocks, no-else-return, line-too-long
Member

Do not disable line-too-long, no-else-return, too-many-nested-blocks.

Member Author

Done

@tqchen
Member

tqchen commented May 13, 2018

@FrozenGene would you be interested in reviewing the code?

@Huyuwei
Member

Huyuwei commented May 13, 2018

@tqchen Sure, will do a review this weekend.


@FrozenGene FrozenGene left a comment


When we don't have 4 elements (top, left, bottom, right) for padding, we can get the wrong result. For example: input_height is 300, input_width is 300, stride is 2, kernel is 3, padding is SAME. The TensorFlow guide is here: https://www.tensorflow.org/api_guides/python/nn#Convolution Following the guide:

out_height = ceil(float(in_height) / float(strides[1]))       # ceil(300 / 2) ===> 150
if in_height % strides[1] == 0:                               # 300 % 2 == 0
    pad_along_height = max(filter_height - strides[1], 0)     # 3 - 2 ===> 1
else:
    pad_along_height = max(filter_height - (in_height % strides[1]), 0)

pad_top = pad_along_height // 2          # 1 // 2 ===> 0
pad_bottom = pad_along_height - pad_top  # ===> 1 - 0 = 1

Your calculation:

def _get_pad(input1d, kernel1d, stride1d):
    out1d = (input1d + stride1d - 1) // stride1d                      # 150
    pad = np.maximum((out1d - 1) * stride1d + kernel1d - input1d, 0)  # 1
    pad = pad // 2                                                    # ===> 0
    return pad                                                        # JUST top / left

Our NNVM implementation:

# compute the padding size
if isinstance(padding, (tuple, list)):
    pad_h = padding[0] * 2  # => 0
    pad_w = padding[1] * 2
# ...
pad_top = (pad_h + 1) // 2  # 0
pad_left = (pad_w + 1) // 2
# pad_top is 0, so pad_bottom = 0 - 0 = 0, but the correct answer is 1
return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left

So if we don't pass 4 elements to NNVM, we cannot have a correct implementation of SAME padding. I know this is not your problem, but I think this is a good chance to point it out.
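To make the off-by-one concrete, here is a small self-contained sketch (plain Python, numbers from the example above) showing that TensorFlow's SAME rule produces an asymmetric pad that a single [top, left] pair cannot express:

def tf_same_pad(in_size, kernel, stride):
    # TensorFlow SAME rule: split the total pad; the extra pixel goes at the end.
    out_size = (in_size + stride - 1) // stride
    pad_total = max((out_size - 1) * stride + kernel - in_size, 0)
    pad_before = pad_total // 2
    pad_after = pad_total - pad_before
    return pad_before, pad_after

print(tf_same_pad(300, 3, 2))  # (0, 1): top=0, bottom=1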

@FrozenGene

FrozenGene commented May 13, 2018

@tqchen Thanks for the invitation. I am doing the padding support as we discussed in the forum (https://discuss.tvmlang.org/t/tvm-seems-conv2d-has-relate-big-floating-point-error-when-to-from-coreml/145/13). My implementation of Conv2D and depthwise2D is complete except for the scheduling part, and we can now get the same results as TensorFlow / CoreML. Previously we got the correct prediction result on MobileNet, but output that differed from TensorFlow / CoreML when the padding was SAME. That problem is now solved. My test models are MobileNet and SSD MobileNet; the model files come from the TensorFlow official repo, converted into CoreML format and then run through NNVM. Currently I have only reviewed the padding part, but I find this implementation also does not carry the 4 elements (top, left, bottom, right), so I think it has the same padding problem.

return _dim_check, "Only 2d kernel supported."

def _infer_channels(inputs, params, transpose=False):
"""A hack for getting 'channles' or 'units' since onnx don't provide
Member

not onnx


if attr['padding'] == 'VALID':
attr['padding'] = [0, 0]
elif attr['padding'] == 'SAME':
Member

@Huyuwei Huyuwei May 14, 2018

I am afraid SAME padding is not correctly converted here, as pointed out by @FrozenGene.
Can you add a test with input_h=32, kernel_h=3, stride_h=2?

You can use nnvm.sym.pad to support SAME padding as in keras frontend: https://github.com/dmlc/nnvm/blob/master/python/nnvm/frontend/keras.py#L153
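For illustration, a sketch of that keras-style conversion adapted to this frontend's NHWC layout (pad values computed with the TensorFlow rule above; conv2d attributes elided into **conv_attrs):

import nnvm.symbol as sym

def conv2d_same(data, pad_t, pad_l, pad_b, pad_r, **conv_attrs):
    # Pad H and W explicitly, then call conv2d with zero padding so the
    # asymmetric SAME padding is honoured.
    padded = sym.pad(data=data,
                     pad_width=((0, 0), (pad_t, pad_b), (pad_l, pad_r), (0, 0)))
    return sym.conv2d(data=padded, padding=[0, 0], **conv_attrs)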

@FrozenGene FrozenGene May 14, 2018

@Huyuwei Aha, this keras implementation is interesting. But how do we handle the workload / schedule part? Our workload / schedule part only concerns VALID padding and only cares about top / left; see: https://github.com/dmlc/tvm/blob/9175b47c16691824b1a54aa26dc1ddb21c981612/topi/python/topi/nn/conv2d.py#L146 The keras implementation solves SAME padding, but it does not pass this padding information to the scheduling part. This is why I discussed with @tqchen in https://discuss.tvmlang.org/t/tvm-seems-conv2d-has-relate-big-floating-point-error-when-to-from-coreml/145. @tqchen suggested extending the Conv2D operator's param (maybe including pooling) from [top, left] to [top, left, bottom, right]. So I want to know about this keras implementation: how do we pass the 4 padding values to scheduling? If we can solve this problem, we don't need to extend the operator's param.

Member Author

@srkreddy1238 srkreddy1238 May 14, 2018

Initially I borrowed the pad_pair calculation from keras, which I have now replaced with the tensorflow reference.
Both work fine with InceptionV1 as long as I don't use a pad operator before convolution but instead pass pad_t, pad_l to conv2d.

I see the below error with the pad operator and am investigating it.

nnvm._base.NNVMError: Error in operator conv2d2: [14:09:41] src/top/nn/convolution.cc:84: Operator conv2d(use_bias=False, strides=(1, 1), channels=64, layout=NHWC, padding=[0, 0], kernel_size=(3, 3), kernel_layout=HWIO, name=conv2d2) expects weight's shape to be [3,3,34,64], but got [3,3,32,64].

Member Author

@srkreddy1238 srkreddy1238 May 14, 2018

The above error was due to a different order of padding values for NHWC, which I corrected.

Now the padding scheme for the TF frontend follows the tensorflow reference, with a pad operator before convolution, and it works fine with InceptionV1.

Member

@FrozenGene The keras frontend now inserts a pad operator before conv to handle SAME padding, and I expected the pad operator to be fused. Consider this sequence: extra_pad_1, pad_1, conv_1, relu_1, extra_pad_2, pad_2, conv_2... relu_1 and extra_pad_2 should both be fused into conv_1, but I just noticed that extra_pad_2 won't be fused in the generated code. So this pad operator brings overhead.

I agree that we can extend conv2d's padding param to support [top, left, bottom, right]. It will make the conversion much easier and clearer. The pad API in keras may be helpful: https://keras.io/layers/convolutional/#zeropadding2d
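For reference, with such an extended parameter the conversion would collapse to a single call (hypothetical signature; today's operator only accepts two values):

import nnvm.symbol as sym

# Hypothetical once conv2d accepts four padding values:
out = sym.conv2d(data=data, channels=64, kernel_size=(3, 3),
                 layout='NHWC', kernel_layout='HWIO',
                 padding=[pad_t, pad_l, pad_b, pad_r])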

Member Author

Another issue I came across with padding on the pool operation while writing test cases for the tensorflow frontend: https://github.com/dmlc/nnvm/issues/486

I think we could move this topic to the discussion forum.


@srkreddy1238 @Huyuwei I have extended padding completely and can get the same results as TensorFlow / CoreML now. However, I kept the [top, left] padding implementation, which is useful for compatibility with old code; for example, the unittest test_infer_shape uses the old behaviour. But if the user specifies [top, left, bottom, right], we go to the new behaviour. Another reason I kept it is that many frameworks (keras, MXNet and so on) use [top, left] padding, and I don't have time to test every framework (especially MXNet; its official documentation says its padding parameter is only [top, left] too). But I think we should only have [top, left, bottom, right] padding. Do you have any ideas or suggestions?

Member Author

Sounds good. Can you submit a PR for review?


@srkreddy1238 After I finish this busy time and add more unittests, I will PR it with CoreML support. I have tested MobileNet / ResNet50 / SSD-MobileNet and they work fine.

Member Author

👍 I will help check tensorflow once you PR it.

# Rearrange inputs from
# (data, moving_mean, moving_variance, beta, gamma) to
# to
# (data, gamma, beta, moving_mean, moving_var)
Member

to to

Member Author

Done

'BatchNormWithGlobalNormalization' : _batch_norm(),
'BiasAdd' : _bias_add(),
'Cast' : _cast(),
'CheckNumerics' : _check_numerics(), # TODO
Member

TODO here is confusing. Maybe comment out unsupported operators?

Member Author

Simplified the TODO with one section explaining the implementation, limitation and assumption aspects.



class GraphProto(object):
""" TODO: A helper class for handling nnvm graph copying from Tensorflow GraphDef.
Member

TODO here is confusing, since GraphProto is implemented.

Member Author

Simplified the TODO with one section explaining the implementation, limitation and assumption aspects.


def from_tensorflow(self, graph):
"""Construct nnvm nodes from tensor flow graph definition - GraphDef.
TODO: Detailed explanation of TF GraphDef parsing.
Member

Add the explanation and remove TODO.

Member Author

Simplified the TODO with one section explaining the implementation, limitation and assumption aspects.

for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))
Member

We can compare TVM output with tensorflow output to show the correctness.

Member Author

Sure, I will add the tensorflow prediction logic and do a comparison.

Also working on the test cases for TF.

@Huyuwei
Member

Huyuwei commented May 14, 2018

@srkreddy1238 Tests should be added.

And many TODOs are quite confusing. I think it's helpful to add comments in the file describing what is not yet supported in the tensorflow converter.

@srkreddy1238 srkreddy1238 force-pushed the master branch 6 times, most recently from fe837f8 to dfc234a Compare May 14, 2018 16:37
@srkreddy1238
Member Author

Done with all the review comments; changes listed below.
Test cases added.
Pad operator similar to keras.
Corrections to internal documentation and coding standards.

Left:
Padding discussion, which is common to multiple frameworks.

@tqchen
Member

tqchen commented May 16, 2018

Waiting for @FrozenGene's and @Huyuwei's updated comments or approval.

@srkreddy1238
Member Author

@FrozenGene & @Huyuwei any comments?

@FrozenGene

Sorry for the late reply. Just as @srkreddy1238 said, we left padding to be discussed later. I will PR it after this busy time, with unittests added. Before the PR, I also want to hear your comments (should we keep compatibility with the old behaviour?).

For the code part: https://github.com/srkreddy1238/nnvm/blob/e68e04ae1c3931f73488991f188ba8ed633c30b3/python/nnvm/frontend/tensorflow.py#L453

                # Assuming only one output.
                self._nodes[node.name] = op
                node_output = op
        # Assume the final node is the output node
        out = node_output

Could you explain why we don't handle multiple outputs? One example of multiple outputs is: https://github.com/tensorflow/models/tree/master/research/object_detection. SSD-MobileNet's core feature extractor has two outputs: 'concat:0' and 'concat_1:0'. I think we can do it the ONNX way: https://github.com/dmlc/nnvm/blob/master/python/nnvm/frontend/onnx.py#L585
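For illustration, a sketch of that ONNX-style ending for this converter (output_names is a hypothetical argument listing the graph's requested outputs):

import nnvm.symbol as sym

# Collect the requested outputs instead of assuming the final node,
# and group them when there is more than one.
out_syms = [self._nodes[name] for name in output_names]
out = out_syms[0] if len(out_syms) == 1 else sym.Group(out_syms)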

https://github.com/srkreddy1238/nnvm/blob/e68e04ae1c3931f73488991f188ba8ed633c30b3/python/nnvm/frontend/tensorflow.py#L249

    def _impl(inputs, attr, params):
        # Making a copy node assuming the input image shape is 299x299
        # Change this when we have corresponding resize bilinear operation.
        pop_node = inputs.pop(1)
        params.pop(pop_node.list_output_names()[0])
        return AttrCvt(op_name="copy", ignores=['align_corners'])(inputs, attr)
    return _impl

I don't think this is a good way to handle it. Why do we assume the shape is 299x299 and simply copy? If other models use this op and produce wrong output (for example, the DeepLab v3 model also uses ResizeBilinear ops), it is difficult to debug (wrong output at runtime is hard to trace). One candidate is tf-coreml's way: https://github.com/tf-coreml/tf-coreml/blob/master/tfcoreml/_layers.py#L832 We have upsampling now, but it only supports nearest-neighbor mode; we lack a bilinear mode. If you add a bilinear mode to upsampling, we could support this correctly.
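For illustration, a sketch of that tf-coreml-style mapping, under two assumptions: the input spatial size is known to the converter, and the target size is an integer multiple of it (bilinear mode itself is still missing, so results would only be nearest-neighbour):

import nnvm.symbol as sym

def _resize_bilinear(in_h, in_w):  # input size assumed known at conversion time
    def _impl(inputs, attr, params):
        # The second input of ResizeBilinear holds the target (height, width).
        size_node = inputs.pop(1)
        out_h, out_w = params.pop(size_node.list_output_names()[0]).asnumpy()
        assert out_h % in_h == 0 and out_w % in_w == 0, "integer scale required"
        # Nearest-neighbour upsampling stands in until a bilinear mode exists.
        return sym.upsampling(inputs[0], scale=int(out_h // in_h))
    return _impl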

For the _convert_map, you implement a limited set of ops. If you can follow the ONNX way (comment out unimplemented ops), I think that is better; it lets others know which ops we don't implement yet. But I know this is boring work, so it is up to you whether to do it.

@srkreddy1238
Member Author

@FrozenGene

(Should we keep compatibility with the old behaviour?): I feel we should keep the compatibility and ensure support for all frameworks.

The tensorflow frontend is initial work with minimal support, which needs to evolve over time in line with TVM/NNVM.

Answering your review points here.

Output nodes: Unlike other frameworks, I don't see the tensorflow protobuf define input and output nodes specifically, hence I assumed the first node as input and the last node as output. I have multiple-output support in mind and am working on it.

ResizeBilinear is a dummy as there is no supporting topi. This can be fixed once we have the topi support in the near future. I will correct the "299x299" comment.

The tensorflow operator list is huge compared to other frameworks, so I chose to list the supported ops. I plan to create categorized lists in a coming PR for the tensorflow frontend.

This is an initial effort towards supporting tensorflow; I am already listing the missing pieces, including the above issues, to bring up with a plan.

@FrozenGene

Yes, we don't have a BilinearResize op. I listed one reference link: CoreML doesn't have a BilinearResize op either, but it uses upsampling to do the work. So if we don't have a BilinearResize op to do the mapping, we should do it CoreML's way, not just copy, which is obviously not correct.

@srkreddy1238
Member Author

I see the CoreML approach, but NNVM doesn't have an operator to support any resize right now.

DecodeJpeg was removed completely, as we decided not to support decoding in NNVM at all. I kept ResizeBilinear as a copy for now, assuming we come up with a topi for a resize operator soon.

I will add a print in the frontend to notify the user about the dummy ResizeBilinear functionality until we have a correct implementation in place.

Feel free to comment on this.

@FrozenGene

We have an UpSampling operator but it doesn't have a bilinear mode. If you want to follow CoreML's approach, you need to add bilinear mode support: https://github.com/dmlc/tvm/blob/785c9420c1604dbaf85195c31be625f496e5776f/topi/python/topi/nn/upsampling.py The keras frontend also uses this op: https://github.com/dmlc/nnvm/blob/master/python/nnvm/frontend/keras.py#L289

Or you should notify the user, as you said.

@srkreddy1238
Member Author

Great, I will update it to use this for the time being, until we have bilinear.

@srkreddy1238
Member Author

srkreddy1238 commented May 21, 2018

Yes, different scaling adds more flexibility.

Consolidated the resize discussion into this thread for follow-up:
https://discuss.tvmlang.org/t/scaling-design-for-nn-nearest-neighbour-and-bilinear/187

@srkreddy1238 srkreddy1238 force-pushed the master branch 2 times, most recently from 0d0569d to 55dc1e6 Compare May 22, 2018 01:49
@srkreddy1238
Member Author

test_graph.test_print_graph_ir ... ok
test_graph.test_gradient ... ./tests/scripts/task_python_test.sh: line 6: 13532 Segmentation fault (core dumped) python -m nose -v tests/python/unittest
script returned exit code 255

@tqchen need your help here :)

@srkreddy1238
Member Author

@tqchen , @Huyuwei & @FrozenGene
Added more ops to the frontend for Inception V3.
Help me with a final review.

@tqchen : Please help with the CI issue :)

Follow the tensorflow graph definition to parse and convert it to NNVM.
Some of the assumptions listed below.

-> First Const node will be comsidered as graph input.


comsidered -> considered

Member Author

Done

@srkreddy1238
Member Author

srkreddy1238 commented May 23, 2018

CI looks good now after rebasing with nnvm/master :)

@srkreddy1238
Member Author

@tqchen , @Huyuwei Any other review comments?

@srkreddy1238 srkreddy1238 changed the title Tensorflow frontend support [REVIEW] Tensorflow frontend support May 24, 2018
@srkreddy1238
Member Author

srkreddy1238 commented May 26, 2018

@tqchen and @Huyuwei Can you guys spare some time on this?

@@ -0,0 +1,315 @@
# pylint: disable=import-self, invalid-name, unused-argument
"""
Tensorflow testcases
Member

The test code doesn't look neat.

Could you remove some functions like _download()? We don't need a real image for testing; just use random numbers.

get_shrunk_inception_shapes() seems messy. Do we really need to test all these conv settings? I prefer the keras frontend testing; you may have a look: https://github.com/dmlc/nnvm/blob/master/tests/python/frontend/keras/test_forward.py
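For example, a keras-style forward test with random input could look like this (run_tf_graph / run_tvm_graph are hypothetical helpers that execute the same GraphDef through TensorFlow and through NNVM/TVM respectively):

import numpy as np

def verify_forward(graph_def, input_name, input_shape, output_name):
    # Random data is enough for a numerical comparison; no image download.
    data = np.random.uniform(size=input_shape).astype('float32')
    tf_out = run_tf_graph(graph_def, data, input_name, output_name)
    tvm_out = run_tvm_graph(graph_def, data, input_name, output_name)
    np.testing.assert_allclose(tf_out, tvm_out, rtol=1e-5, atol=1e-5)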

@srkreddy1238
Member Author

Sure, I will replace the real-image test case with operator-specific cases using random input data.

The convolution possibilities are inherited from the tensorflow test cases; I feel it's good to have more possibilities covered.

What do you suggest?

srkreddy1238 and others added 6 commits May 27, 2018 19:35
    optional tensorflow dependency.
    do not disable line-too-long, no-else-return, too-many-nested-blocks
    pad pair calculation as described by tensorflow.
    Internal documentation corrections.
    Indentation corrections.
        Inception v1
        Convolution
    ConcatV2, Rsqrt, Add, Squeeze
@Huyuwei
Member

Huyuwei commented May 27, 2018

@srkreddy1238 It's better to cover more operators and more network architectures, like other frontend tests. Other people may want to improve the tensorflow frontend and add unit tests, so please clean the testing code to make that easy.

@srkreddy1238
Member Author

@Huyuwei Updated the test cases. Please review.

@tqchen
Member

tqchen commented May 29, 2018

cf. #518: we will redirect further changes to the tvm repo; please open a new PR there. Please invite the original set of reviewers when the new PR is opened so they can review and approve the changes.

@tqchen tqchen closed this May 29, 2018
abergeron pushed a commit to abergeron/nnvm that referenced this pull request May 31, 2018
* migrate global_avg_pool, fully_connected

* fix pylint

* enable fusion of pooling schedule

* rename fc->dense, enable fusion

* improve dense schedule

* unified global pool