
📖 [Story] Windows Support #2495

Open · 3 tasks done
gs-olive opened this issue Nov 28, 2023 · 3 comments

Assignees: gs-olive
Labels: Story (Issues proposing a new Story)

gs-olive (Collaborator) commented Nov 28, 2023

TL;DR

First iteration of Windows support task tracking.

Tasks

  1. (bug) assigned to gs-olive
  2. (feature request, 0 of 1 tasks complete) assigned to gs-olive, narendasan
  3. assigned to gs-olive
gs-olive added the Story label Nov 28, 2023
gs-olive self-assigned this Nov 28, 2023
HolyWu (Contributor) commented Feb 27, 2024

I have managed to build the C++ and Python API with Bazel on Windows. While the torch_tensorrt package can be imported successfully and the _C extension does get loaded, inference still has problems.

First, I tried to run the script below with ir="ts", and it actually worked.

import torch
import torch_tensorrt

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.m = torch.nn.LeakyReLU()
        
    def forward(self, x):
        return self.m(x)
    
model = MyModule().eval().cuda()
inputs = [torch.randn((1, 3, 8, 8), dtype=torch.float, device="cuda")]

optimized_model = torch_tensorrt.compile(
    model,
    ir="ts",
    inputs=inputs,
    enabled_precisions={torch.float},
    debug=True,
    min_block_size=1,
)

print(optimized_model(*inputs))
torch._dynamo.reset()

Then, I ran the above script again with ir="dynamo", but it silently crashed without any message after the TRT engine had been built. I also tried adding output_format="fx", but to no avail. At this point I have no clue what the culprit is.
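
For reference, the ir="dynamo" run is the same call with only the frontend changed (output_format="fx" being the extra keyword I tried); sketched below for clarity:

# Same module and inputs as the script above; only the frontend differs.
optimized_model = torch_tensorrt.compile(
    model,
    ir="dynamo",                 # Dynamo frontend instead of TorchScript
    inputs=inputs,
    enabled_precisions={torch.float},
    debug=True,
    min_block_size=1,
    # output_format="fx",        # also tried, to no avail
)

print(optimized_model(*inputs))  # this is where the silent crash happens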

The changes I have made are at https://github.com/HolyWu/TensorRT/tree/windows.

Here is the Python wheel if anyone wants to try it.
torch_tensorrt-2.3.0.dev0+afd5abebb-cp311-cp311-win_amd64.zip

HolyWu (Contributor) commented Mar 2, 2024

I think I have found the culprit of the crash. The crash upon inference is caused by the call to torch.ops.tensorrt.execute_engine in the forward function of TorchTensorRTModule. It seems that the tensorrt.cp311-win_amd64.pyd extension in Windows TensorRT 8.6.1 has a weird bug with torch.ops. Building torch_tensorrt against the TensorRT 9.2.0.5 libraries and installing the 9.2.0.5 (or 9.3.0.1) Python package eliminates the crash. Now the above script successfully outputs the inference result with ir="dynamo", although there are a lot of "Could not register plugin creator" messages when torch_tensorrt is imported and I don't know why (maybe there are some API changes between TRT 8 and TRT 9?).

[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::BatchedNMSDynamic_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::BatchedNMS_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::BatchTilePlugin_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::Clip_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::CoordConvAC version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::CropAndResizeDynamic version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::CropAndResize version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::DecodeBbox3DPlugin version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::DetectionLayer_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::EfficientNMS_Explicit_TF_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::EfficientNMS_Implicit_TF_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::EfficientNMS_ONNX_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::EfficientNMS_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::FlattenConcat_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::GenerateDetection_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::GridAnchor_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::GridAnchorRect_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::InstanceNormalization_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::InstanceNormalization_TRT version 2
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::LReLU_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::ModulatedDeformConv2d version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::MultilevelCropAndResize_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::MultilevelProposeROI_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::MultiscaleDeformableAttnPlugin_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::NMSDynamic_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::NMS_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::Normalize_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::PillarScatterPlugin version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::PriorBox_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::ProposalDynamic version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::ProposalLayer_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::Proposal version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::PyramidROIAlign_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::Region_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::Reorg_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::ResizeNearest_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::ROIAlign_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::RPROI_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::ScatterND version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::SpecialSlice_TRT version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::Split version 1
[03/02/2024-16:57:02] [TRT] [E] Could not register plugin creator -  ::VoxelGeneratorPlugin version 1
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %arg0_1 : [num_users=3] = placeholder[target=arg0_1]
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%arg0_1, 0), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, 0.01), kwargs = {})
    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%gt, %arg0_1, %mul), kwargs = {})
    return (where,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=3] = placeholder[target=arg0_1]
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%arg0_1, 0), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, 0.01), kwargs = {})
    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%gt, %arg0_1, %mul), kwargs = {})
    return (where,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %arg0_1 : [num_users=3] = placeholder[target=arg0_1]
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%arg0_1, 0), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, 0.01), kwargs = {})
    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%gt, %arg0_1, %mul), kwargs = {})
    return (where,)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(precision=torch.float32, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.DEFAULT: 0>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, output_format='exported_program')

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.gt.Scalar + Operator Count: 1
- torch.ops.aten.mul.Tensor + Operator Count: 1
- torch.ops.aten.where.self + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 3 operators out of 3 in subgraph.
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.gt.Scalar + Operator Count: 1
- torch.ops.aten.mul.Tensor + Operator Count: 1
- torch.ops.aten.where.self + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Submodule name: _run_on_acc_0
 Input shapes: [(1, 3, 8, 8)]
 graph():
    %arg0_1 : [num_users=3] = placeholder[target=arg0_1]
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%arg0_1, 0), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, 0.01), kwargs = {})
    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%gt, %arg0_1, %mul), kwargs = {})
    return where
[03/02/2024-16:57:03] [TRT] [I] [MemUsageChange] Init CUDA: CPU +4, GPU +0, now: CPU 13133, GPU 1009 (MiB)
[03/02/2024-16:57:03] [TRT] [V] Trying to load shared library nvinfer_builder_resource.dll
[03/02/2024-16:57:03] [TRT] [V] Loaded shared library nvinfer_builder_resource.dll
[03/02/2024-16:57:05] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +2731, GPU +310, now: CPU 16133, GPU 1319 (MiB)
[03/02/2024-16:57:05] [TRT] [V] CUDA lazy loading is enabled.
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Node meta name arg0_1
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Node meta name /m/gt
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Node meta name /m/mul
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Node meta name /m/where
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Node meta name output
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.001948
[03/02/2024-16:57:05] [TRT] [V] Original: 7 layers
[03/02/2024-16:57:05] [TRT] [V] After dead-layer removal: 7 layers
[03/02/2024-16:57:05] [TRT] [V] Graph construction completed in 0.0003945 seconds.
[03/02/2024-16:57:05] [TRT] [V] Running: ConstShuffleFusion on /m/gt_rhs
[03/02/2024-16:57:05] [TRT] [V] ConstShuffleFusion: Fusing /m/gt_rhs with /m/gt_rhs_broadcast
[03/02/2024-16:57:05] [TRT] [V] Running: ConstShuffleFusion on /m/mul_rhs
[03/02/2024-16:57:05] [TRT] [V] ConstShuffleFusion: Fusing /m/mul_rhs with /m/mul_rhs_broadcast
[03/02/2024-16:57:05] [TRT] [V] After Myelin optimization: 1 layers
[03/02/2024-16:57:05] [TRT] [V] Applying ScaleNodes fusions.
[03/02/2024-16:57:05] [TRT] [V] After scale fusion: 1 layers
[03/02/2024-16:57:05] [TRT] [V] After dupe layer removal: 1 layers
[03/02/2024-16:57:05] [TRT] [V] After final dead-layer removal: 1 layers
[03/02/2024-16:57:05] [TRT] [V] After tensor merging: 1 layers
[03/02/2024-16:57:05] [TRT] [V] After vertical fusions: 1 layers
[03/02/2024-16:57:05] [TRT] [V] After dupe layer removal: 1 layers
[03/02/2024-16:57:05] [TRT] [V] After final dead-layer removal: 1 layers
[03/02/2024-16:57:05] [TRT] [V] After tensor merging: 1 layers
[03/02/2024-16:57:05] [TRT] [V] After slice removal: 1 layers
[03/02/2024-16:57:05] [TRT] [V] After concat removal: 1 layers
[03/02/2024-16:57:05] [TRT] [V] Trying to split Reshape and strided tensor
[03/02/2024-16:57:05] [TRT] [V] Graph optimization time: 0.0003457 seconds.
[03/02/2024-16:57:05] [TRT] [V] Building graph using backend strategy 2
[03/02/2024-16:57:05] [TRT] [I] Global timing cache in use. Profiling results in this builder pass will be stored.
[03/02/2024-16:57:05] [TRT] [V] Constructing optimization profile number 0 [1/1].
[03/02/2024-16:57:05] [TRT] [V] Applying generic optimizations to the graph for inference.
[03/02/2024-16:57:05] [TRT] [V] Reserving memory for host IO tensors. Host: 0 bytes
[03/02/2024-16:57:05] [TRT] [V] =============== Computing costs for {ForeignNode[/m/gt_rhs + /m/gt_rhs_broadcast...[SELECT]-[aten_ops.where.self]-[/m/where_select]]}
[03/02/2024-16:57:05] [TRT] [V] *************** Autotuning format combination: Float(192,64,8,1) -> Float(192,64,8,1) ***************
[03/02/2024-16:57:05] [TRT] [V] --------------- Timing Runner: {ForeignNode[/m/gt_rhs + /m/gt_rhs_broadcast...[SELECT]-[aten_ops.where.self]-[/m/where_select]]} (Myelin[0x80000023])
[03/02/2024-16:57:05] [TRT] [V] [MemUsageChange] Subgraph create: CPU +0, GPU +0, now: CPU 16136, GPU 1319 (MiB)
[03/02/2024-16:57:05] [TRT] [V]  (foreignNode) Set user's cuda kernel library
[03/02/2024-16:57:05] [TRT] [V] Subgraph compilation completed in 0.141 seconds.
[03/02/2024-16:57:05] [TRT] [V] [MemUsageChange] Subgraph compilation: CPU +58, GPU +0, now: CPU 16194, GPU 1319 (MiB)
[03/02/2024-16:57:05] [TRT] [V] [runner] Allocating resources for 1 graphs.
[03/02/2024-16:57:05] [TRT] [V] Tactic: 0x0000000000000000 Time: 0.0064407
[03/02/2024-16:57:05] [TRT] [V] {ForeignNode[/m/gt_rhs + /m/gt_rhs_broadcast...[SELECT]-[aten_ops.where.self]-[/m/where_select]]} (Myelin[0x80000023]) profiling completed in 0.149506 seconds. Fastest Tactic: 0x0000000000000000 Time: 0.0064407
[03/02/2024-16:57:05] [TRT] [V] >>>>>>>>>>>>>>> Chose Runner Type: Myelin Tactic: 0x0000000000000000
[03/02/2024-16:57:05] [TRT] [V] =============== Computing reformatting costs for available format set
[03/02/2024-16:57:05] [TRT] [V] =============== Computing reformatting costs for available format set
[03/02/2024-16:57:05] [TRT] [V] Formats and tactics selection completed in 0.150114 seconds.
[03/02/2024-16:57:05] [TRT] [V] After reformat layers: 1 layers
[03/02/2024-16:57:05] [TRT] [V] Total number of blocks in pre-optimized block assignment: 1
[03/02/2024-16:57:05] [TRT] [I] Detected 1 inputs and 1 output network tensors.
[03/02/2024-16:57:05] [TRT] [V] Layer: {ForeignNode[/m/gt_rhs + /m/gt_rhs_broadcast...[SELECT]-[aten_ops.where.self]-[/m/where_select]]} Host Persistent: 32 Device Persistent: 0 Scratch Memory: 0
[03/02/2024-16:57:05] [TRT] [V] Skipped printing memory information for 0 layers with 0 memory size i.e. Host Persistent + Device Persistent + Scratch Memory == 0.
[03/02/2024-16:57:05] [TRT] [I] Total Host Persistent Memory: 32
[03/02/2024-16:57:05] [TRT] [I] Total Device Persistent Memory: 0
[03/02/2024-16:57:05] [TRT] [I] Total Scratch Memory: 0
[03/02/2024-16:57:05] [TRT] [V] Total number of blocks in optimized block assignment: 0
[03/02/2024-16:57:05] [TRT] [I] Total Activation Memory: 0
[03/02/2024-16:57:05] [TRT] [I] Total Weights Memory: 0
[03/02/2024-16:57:05] [TRT] [V] Total number of generated kernels selected for the engine: 0
[03/02/2024-16:57:05] [TRT] [V] Disabling unused tactic source: EDGE_MASK_CONVOLUTIONS
[03/02/2024-16:57:05] [TRT] [V] Disabling unused tactic source: JIT_CONVOLUTIONS
[03/02/2024-16:57:05] [TRT] [I] Engine generation completed in 0.152196 seconds.
[03/02/2024-16:57:05] [TRT] [V] Engine Layer Information:
Layer(Myelin): {ForeignNode[/m/gt_rhs + /m/gt_rhs_broadcast...[SELECT]-[aten_ops.where.self]-[/m/where_select]]}, Tactic: 0x0000000000000000, arg0_1 (Float[1,3,8,8]) -> output0 (Float[1,3,8,8])
[03/02/2024-16:57:05] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 1 MiB
[03/02/2024-16:57:05] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +0, GPU +0, now: CPU 0, GPU 0 (MiB)
[03/02/2024-16:57:05] [TRT] [V] Serializing timing cache. UUID = GPU-84374c83-93ab-0394-cd0d-66acc087082d, commit ID = d6cbd29d5253f99c
[03/02/2024-16:57:05] [TRT] [I] Serialized 52 bytes of code generator cache.
[03/02/2024-16:57:05] [TRT] [I] Serialized 8113 bytes of compilation cache.
[03/02/2024-16:57:05] [TRT] [I] Serialized 0 timing cache entries
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.153897
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 0 bytes of Memory
[03/02/2024-16:57:05] [TRT] [V] Adding 1 engine(s) to plan file.
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 3 Total Operators, of which 3 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(precision=torch.float32, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.DEFAULT: 0>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, output_format='exported_program')

  Graph Structure:

   Inputs: List[Tensor: (1, 3, 8, 8)@float32]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (1, 3, 8, 8)@float32]
     Number of Operators in Engine: 3
     Engine Outputs: Tensor: (1, 3, 8, 8)@float32
    ...
   Outputs: List[Tensor: (1, 3, 8, 8)@float32]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 3.0
   Most Operators in a TRT Engine: 3

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=3 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=3 which generates 1 TRT engine(s)
WARNING: [Torch-TensorRT] - Using default stream in enqueue()/enqueueV2()/enqueueV3() may lead to performance issues due to additional cudaDeviceSynchronize() calls by TensorRT to ensure correct synchronizations. Please use non-default stream instead.
C:\Python311\Lib\site-packages\torch\export\exported_program.py:740: UserWarning: Unable to execute the generated python source code from the graph. The graph module will no longer be directly callable, but you can still run the ExportedProgram, and if needed, you can run the graph module eagerly using torch.fx.Interpreter.
  warnings.warn(
(tensor([[[[ 6.2006e-01, -2.9064e-04,  9.8226e-01, -3.8333e-03, -2.1618e-02,
           -1.5738e-02,  2.6443e-01, -1.4374e-02],
          [-1.2791e-02, -1.1879e-02,  4.7439e-01,  6.9313e-01, -1.0766e-02,
            1.3728e+00,  8.3758e-01,  7.4307e-01],
          [-3.5098e-03,  4.5662e-01, -1.6441e-03, -1.5583e-02, -1.3701e-02,
            5.3097e-01, -3.2285e-03, -3.3189e-03],
          [-7.4308e-03, -1.6767e-02,  2.5757e-02,  4.5823e-01,  5.1807e-01,
            1.8787e-01, -1.3362e-02,  6.9791e-01],
          [-2.1611e-02,  1.0629e+00, -2.0338e-02, -1.7173e-03, -3.2053e-03,
            6.8387e-01, -1.0706e-02, -1.0811e-02],
          [-4.3017e-04,  1.4976e+00,  5.9703e-02, -7.9594e-03,  5.5089e-01,
           -8.1892e-03,  2.8588e-01, -1.5868e-02],
          [ 4.8508e-01,  1.1271e+00,  1.7785e+00, -7.4616e-03, -4.5771e-03,
           -1.4308e-02, -4.0018e-03,  3.5549e-01],
          [-5.8241e-03, -3.8064e-03, -5.9313e-04, -1.2990e-02,  3.8060e-01,
            6.4228e-03, -8.0151e-03, -1.0297e-03]],

         [[-7.0621e-03, -2.4972e-02,  2.4519e+00, -1.6451e-04,  4.1318e-01,
           -1.5970e-03,  1.8320e+00, -1.6313e-03],
          [-1.6109e-03,  2.8662e-01,  1.7991e+00, -2.8926e-04, -4.4925e-03,
            1.9756e+00, -1.4357e-03, -2.6023e-02],
          [-1.0665e-02, -1.0338e-03, -8.8272e-03,  1.4398e-01, -4.2988e-03,
           -8.1295e-03,  4.7497e-01, -6.1760e-03],
          [ 4.0540e-01, -1.6224e-02, -6.9385e-04,  6.9263e-01, -1.7464e-03,
            2.4541e-01, -1.2534e-02,  1.0720e+00],
          [-1.0139e-03, -1.1172e-02,  6.7257e-01, -3.4479e-03,  1.3535e+00,
            6.0280e-01, -1.0561e-02,  1.7536e-01],
          [ 2.7637e-01, -6.3729e-03, -4.7551e-03,  9.1582e-02,  5.9293e-01,
           -3.4702e-03, -4.8539e-03,  7.7745e-01],
          [ 4.6169e-01,  1.1525e+00,  4.7898e-02,  5.5210e-01,  6.2328e-01,
            4.1276e-01,  5.5681e-01, -8.6993e-03],
          [-9.9063e-03, -5.2969e-03, -2.0014e-02, -1.3458e-02,  9.3519e-01,
            1.8661e+00,  8.6441e-02, -9.3390e-03]],

         [[ 6.8911e-01, -1.4560e-02,  8.1485e-01,  9.2810e-02, -5.7742e-04,
           -2.2630e-02, -5.7046e-03, -9.1001e-03],
          [-9.7138e-03,  1.0972e+00, -3.4715e-03, -3.4034e-03, -9.6438e-03,
           -4.8049e-04,  1.0703e+00, -1.5452e-02],
          [-9.4251e-03, -1.3537e-03, -1.6299e-02, -9.5506e-03, -5.3317e-04,
           -7.3883e-03,  4.1210e-03, -6.1995e-03],
          [ 3.0561e-02, -7.9463e-03,  5.3772e-01,  1.0653e+00,  1.6256e+00,
            6.9731e-01,  6.4045e-02,  1.5135e+00],
          [ 1.7414e+00,  8.4223e-01,  2.9087e-02,  1.7046e-02,  6.8737e-01,
            2.9057e-01, -1.3581e-03, -1.9831e-02],
          [ 3.6773e-01,  8.3533e-01,  6.4043e-03, -5.6911e-03,  1.5034e-01,
            1.0129e+00,  1.6859e-01,  2.0443e+00],
          [-1.6103e-02,  1.7941e+00,  6.6912e-01, -3.1686e-04,  8.3015e-01,
           -5.2209e-03, -5.9463e-03,  1.7722e+00],
          [-3.0500e-03, -3.5577e-03, -7.3850e-03,  3.5458e-02,  4.7506e-01,
            1.0233e+00,  8.1644e-01, -1.4635e-02]]]], device='cuda:0'),)

Here is the Python wheel built with TensorRT 9.2.0.5.
torch_tensorrt-2.3.0.dev0+741ec5b6b-cp311-cp311-win_amd64.zip
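
For anyone trying the wheel, here is a quick sanity check (just an illustrative snippet of mine, not something shipped with the wheel) to confirm that the TensorRT 9.x Python package is the one being picked up and that the custom op behind TorchTensorRTModule.forward got registered:

import tensorrt
import torch
import torch_tensorrt  # loads the _C extension, which registers torch.ops.tensorrt.*

# Should report 9.2.0.5 / 9.3.0.1, not the 8.6.1 build that triggers the crash.
print(tensorrt.__version__)

# Resolving the op raises if registration failed; this is the call made
# inside TorchTensorRTModule.forward during inference.
print(torch.ops.tensorrt.execute_engine)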

gs-olive (Collaborator, Author) commented Mar 4, 2024

@HolyWu - thank you very much for this information - this is very helpful. I will take a look at your branch as well
