Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix upsample converter not properly registered #2683

Merged
merged 1 commit into from
Apr 19, 2024

Conversation

HolyWu
Copy link
Contributor

@HolyWu HolyWu commented Mar 10, 2024

Description

Partially #2665

Even though the operator is properly registered along with #2681 being applied, the operator is still decomposed into lower-level operators rather than converted using this converter, just like #2665 (comment). Adding aten.upsample_bilinear2d.default and aten.upsample_bilinear2d.vec to torch_disabled_decompositions doesn't help. Compiling the model under with torch.inference_mode() also doesn't help. At the end I find out that I have to remove these two lines and this line in PyTorch to bypass the decomposition and then this converter finally works.

DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.default](args = (%arg0_1, [256, 256], True, 2.0, 2.0), kwargs = {})
    return (upsample_bilinear2d,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.default](args = (%arg0_1, [256, 256], True, 2.0, 2.0), kwargs = {})
    return (upsample_bilinear2d,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.default](args = (%arg0_1, [256, 256], True, 2.0, 2.0), kwargs = {})
    return (upsample_bilinear2d,)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(precision=torch.float16, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.DEFAULT: 0>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, output_format='exported_program')

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.upsample_bilinear2d.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.upsample_bilinear2d.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Submodule name: _run_on_acc_0
 Input shapes: [(1, 3, 128, 128)]
 graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.default](args = (%arg0_1, [256, 256], True, 2.0, 2.0), kwargs = {})
    return upsample_bilinear2d
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Node meta name arg0_1
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Node meta name __/upsample_bilinear2d
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Node meta name output
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.000980
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:01.093642
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 0 bytes of Memory
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(precision=torch.float16, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.DEFAULT: 0>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, output_format='exported_program')

  Graph Structure:

   Inputs: List[Tensor: (1, 3, 128, 128)@float16]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (1, 3, 128, 128)@float16]
     Number of Operators in Engine: 1
     Engine Outputs: Tensor: (1, 3, 256, 256)@float16
    ...
   Outputs: List[Tensor: (1, 3, 256, 256)@float16]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 1.0
   Most Operators in a TRT Engine: 1

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)
WARNING: [Torch-TensorRT] - Using default stream in enqueue()/enqueueV2()/enqueueV3() may lead to performance issues due to additional cudaDeviceSynchronize() calls by TensorRT to ensure correct synchronizations. Please use non-default stream instead.

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actions github-actions bot added component: conversion Issues re: Conversion stage component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Mar 10, 2024
@github-actions github-actions bot requested a review from peri044 March 10, 2024 09:14
@apbose
Copy link
Collaborator

apbose commented Mar 12, 2024

Thanks for the analysis and pointing out the above!

I looked at it and looks like in the above case the AOT trace is returning the decomposition for torch.nn.functional.interpolate Pre-AOT trace

graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %interpolate : [num_users=1] = call_function[target=torch.nn.functional.interpolate](args = (%clone_default,), kwargs = {size: None, scale_factor: 2, mode
: bilinear, align_corners: True, recompute_scale_factor: None, antialias: False})
    return (interpolate,)

Post the torch.export or the AOT trace the graph decomposes into a big graph

    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
    %_to_copy : [num_users=4] = call_function[target=torch.ops.aten._to_copy.default](args = (%clone,), kwargs = {dtype: torch.float32})
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.start_step](args = (0, 256), kwargs = {dtype: torch.float32, layout: torch.strided, $
evice: cuda:0, pin_memory: False})
    %arange_1 : [num_users=1] = call_function[target=torch.ops.aten.arange.start_step](args = (0, 256), kwargs = {dtype: torch.float32, layout: torch.strided$
 device: cuda:0, pin_memory: False})
    %mul : [num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arange, 0.4980392156862745), kwargs = {})
    %mul_1 : [num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arange_1, 0.4980392156862745), kwargs = {})
    %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%mul,), kwargs = {dtype: torch.int64})
    %ceil : [num_users=1] = call_function[target=torch.ops.aten.ceil.default](args = (%mul,), kwargs = {})
    %clamp : [num_users=1] = call_function[target=torch.ops.aten.clamp.default](args = (%ceil, None, 127), kwargs = {})
    %_to_copy_2 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%clamp,), kwargs = {dtype: torch.int64})
    %_to_copy_3 : [num_users=3] = call_function[target=torch.ops.aten._to_copy.default](args = (%mul_1,), kwargs = {dtype: torch.int64})
    %ceil_1 : [num_users=1] = call_function[target=torch.ops.aten.ceil.default](args = (%mul_1,), kwargs = {})
    %clamp_1 : [num_users=1] = call_function[target=torch.ops.aten.clamp.default](args = (%ceil_1, None, 127), kwargs = {})
    %_to_copy_4 : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%clamp_1,), kwargs = {dtype: torch.int64})
    %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%mul, 1), kwargs = {})
    %unsqueeze_1 : [num_users=3] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%_to_copy_1, 1), kwargs = {})
    %unsqueeze_2 : [num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%_to_copy_2, 1), kwargs = {})
    %index : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%_to_copy, [None, None, %unsqueeze_1, %_to_copy_3]), kwargs = {})
    %index_1 : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%_to_copy, [None, None, %unsqueeze_2, %_to_copy_3]), kwargs = {})
    %index_2 : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%_to_copy, [None, None, %unsqueeze_1, %_to_copy_4]), kwargs = {})
    %index_3 : [num_users=1] = call_function[target=torch.ops.aten.index.Tensor](args = (%_to_copy, [None, None, %unsqueeze_2, %_to_copy_4]), kwargs = {})
    %sub : [num_users=3] = call_function[target=torch.ops.aten.sub.Tensor](args = (%unsqueeze, %unsqueeze_1), kwargs = {})
    %sub_1 : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (1.0, %sub), kwargs = {})
    %sub_2 : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_1, %_to_copy_3), kwargs = {})
    %sub_3 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (1.0, %sub_2), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%index, %sub_1), kwargs = {})
    %mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%index_1, %sub), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_2, %mul_3), kwargs = {})
    %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%index_2, %sub_1), kwargs = {})
    %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%index_3, %sub), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
    %mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %sub_3), kwargs = {})
    %mul_7 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_1, %sub_2), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_6, %mul_7), kwargs = {})
    %_to_copy_5 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%add_2,), kwargs = {dtype: torch.float16})
    return (_to_copy_5,)

As far as I understand the torch._decomp.get_decomposition decomposition is not responsible for this decomposition. (That is the reason removing this does not make any difference). As pointed out by you the above decomp in torch is leading to this, but I am not sure if we want to do away with that decomposition by changing the torch code. Couple of things

  1. I think @dheerajperi and @gs-olive would know better to comment on the above
  2. Meanwhile after the decomposition I see the index.Tensor converter throwing an error in the above case. I am looking at the error to debug

@gs-olive
Copy link
Collaborator

@apbose - does the decomposition into that large set of operators you showed still occur if we remove the following two lines (but don't add anything to torch_disabled_decompositions)?

aten.upsample_bilinear2d,
aten.upsample_bilinear2d.vec,

@apbose
Copy link
Collaborator

apbose commented Mar 14, 2024

@gs-olive, yes the above operation decomposes into the large set of ops when the two lines shown above has been commented.
Also a side doubt, when you say that they should not be added to torch_disabled_decompositions, ideally I get why you said that since that would be an unnecessary addition I suppose? Since torch_disabled_decompositions would be effective only when enable_experimental_decompositions would be True where all the core_aten_decompostions decompositions would be considered except the disabled ones? And in this case the above param is False.

@narendasan narendasan merged commit dde535e into pytorch:main Apr 19, 2024
16 of 21 checks passed
@HolyWu HolyWu deleted the upsample branch April 19, 2024 04:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants