Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Add pass for getting calibration data from a relay module #5997

Merged
merged 27 commits into from
Jul 13, 2020

Conversation

seanlatias
Copy link
Contributor

RFC discussion: https://discuss.tvm.ai/t/rfc-byoc-data-calibration-flow/7099/15

This PR implements the analysis pass get_calibration_data mentioned in the RFC. The main functionality of this analysis pass is allowing users to easily get the calibration data from any Relay module. The calibration data includes the input and output tensor values of each subgraph in the module. Following is an example.

The input relay graph that contains two subgraphs:

def @dnnl0(%dnnl0_i0: Tensor[(3, 3), float32], %dnnl0_i1: Tensor[(3, 3), float32]) -> Tensor[(3, 3), float32] {
  add(%dnnl0_i0, dnnl0_i1) 
}

def @dnnl1(%dnnl0_i0: Tensor[(3, 3), float32], %dnnl0_i1: Tensor[(3, 3), float32]) -> Tensor[(3, 3), float32] {
  sub(%dnnl0_i0, dnnl0_i1) 
}

def @main(%data0: Tensor[(3, 3), float32], %data1: Tensor[(3, 3), float32], %data2: Tensor[(3, 3), float32]) -> Tensor[(3, 3), float32] {
  %0 = @dnnl0(%data0, %data1)
  @dnnl1(%0, %data2)
}

The Python API

mod = # the above relay graph
data = {"data0": ..., "data1": ..., "data2": ...}

calib_data = relay.analysis.get_calibration_data(mod, data)
print(calib_data)

The expected output

{@dnnl0: {"inputs": [%data0, %data1], "outputs": [%0]},
 @dnnl1: {"inputs": [%0, %data2], "outputs":[%out]}}

As can be seen, the output calibration data is a two-level dictionary. The first level takes in the GlobalVar of a subgraph as a key and the value is the second-level dictionary. The second level always has two keys: inputs and outputs, which map to the real tensor values of each subgraph. For more complex example please refer to the test.

@comaniac @zhiics

python/tvm/relay/analysis/analysis.py Show resolved Hide resolved
python/tvm/relay/analysis/analysis.py Show resolved Hide resolved
python/tvm/relay/analysis/analysis.py Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
tests/python/relay/test_analysis_get_calibration_data.py Outdated Show resolved Hide resolved
def get_calibration_data(mod, data):
"""Get the calibration data of a given relay graph

This pass use the graph runtime to get the calibration data of a module, which
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may need to think of the semantic for the module with control flows, which has to use VM instead of graph runtime.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
This pass use the graph runtime to get the calibration data of a module, which
This pass uses the graph runtime to get the calibration data of a module, which

Per offline discussion, please mention that this pass only works for the graph without control flow.

src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
def get_calibration_data(mod, data):
"""Get the calibration data of a given relay graph

This pass use the graph runtime to get the calibration data of a module, which
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
This pass use the graph runtime to get the calibration data of a module, which
This pass uses the graph runtime to get the calibration data of a module, which

Per offline discussion, please mention that this pass only works for the graph without control flow.

src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
src/relay/analysis/get_calibration_data.cc Outdated Show resolved Hide resolved
Copy link
Member

@zhiics zhiics left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhiics zhiics merged commit 96fe315 into apache:master Jul 13, 2020
@zhiics
Copy link
Member

zhiics commented Jul 13, 2020

Thanks @seanlatias @comaniac

@seanlatias seanlatias deleted the calibrate branch July 14, 2020 11:35
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Jul 14, 2020
…ache#5997)

* add simple pass to extract outputs

* complete pass that collects all function inputs/outputs

* add analysis pass for collecting outputs

* reorganize the files

* add the first test

* update test with tuples

* clean up Python code

* merge with upstream

* clean up transform.py

* add comments for cpp files

* fix lint issues

* update submodules

* modify files according to the review

* fix style and typo

* fix lint error

* add checks for repeated function calls

* fix lint error

* merge review comments

* small simplification

* revise the code according to the review comments

* add username in TODO

* use IRModule directly

* use better APIs according to the review

* apply comments from the reviewer

* retrigger ci
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jul 14, 2020
…ache#5997)

* add simple pass to extract outputs

* complete pass that collects all function inputs/outputs

* add analysis pass for collecting outputs

* reorganize the files

* add the first test

* update test with tuples

* clean up Python code

* merge with upstream

* clean up transform.py

* add comments for cpp files

* fix lint issues

* update submodules

* modify files according to the review

* fix style and typo

* fix lint error

* add checks for repeated function calls

* fix lint error

* merge review comments

* small simplification

* revise the code according to the review comments

* add username in TODO

* use IRModule directly

* use better APIs according to the review

* apply comments from the reviewer

* retrigger ci
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants