Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Numpy changes for aten::index converter #2396

Merged
merged 6 commits into from
Nov 8, 2023
Merged

Numpy changes for aten::index converter #2396

merged 6 commits into from
Nov 8, 2023

Conversation

apbose
Copy link
Collaborator

@apbose apbose commented Oct 13, 2023

This PR addresses the changes for the cases where indices are list of numpy arrays or numpy for aten::index. #2394

@github-actions github-actions bot added component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Oct 13, 2023
@github-actions github-actions bot requested a review from peri044 October 13, 2023 16:23
@apbose apbose marked this pull request as draft October 13, 2023 16:24
@apbose apbose removed the request for review from peri044 October 13, 2023 16:24
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2023-10-13 16:23:31.589620+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2023-10-13 16:26:06.957081+00:00
@@ -178,12 +178,11 @@
        rhs_val = cast_trt_tensor(ctx, rhs_val, trt.float32, name)
    return [lhs_val, rhs_val]


def broadcastable(
-    a: Union[TRTTensor, np.ndarray],
-    b: Union[TRTTensor, np.ndarray]
+    a: Union[TRTTensor, np.ndarray], b: Union[TRTTensor, np.ndarray]
) -> bool:
    "Check if two tensors are broadcastable according to torch rules"
    a_shape = tuple(a.shape)
    b_shape = tuple(b.shape)

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

@apbose apbose self-assigned this Oct 16, 2023
@apbose apbose requested a review from gs-olive October 16, 2023 22:36
@apbose apbose marked this pull request as ready for review October 16, 2023 22:37
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

@@ -127,6 +140,7 @@ def index(
rank = len(input_shape)
adv_indx_count = len(adv_indx_indices)
dim_tensor_list = []
dim_list = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is unused, dim_list = [] can be removed

Copy link
Collaborator

@gs-olive gs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pending the above comment, it looks good to me!

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link
Collaborator

@gs-olive gs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works well on the SD usecase, and tests pass locally - could a test case be added specifically for this addition, which tests the functionality when all inputs are constants?

Copy link
Collaborator Author

@apbose apbose left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @gs-olive. Regarding adding of the test for this feature, the earlier tests were all numpy indices, where is_numpy is set as True. So those tests should be good for this. Is there any way so that tests can be added so that the indices are ITensors?

@gs-olive
Copy link
Collaborator

Yes, there is. You would just need to pass the indices as inputs to the forward function, instead of having them as constants in the model/graph

@github-actions github-actions bot added the component: tests Issues re: Tests label Oct 30, 2023
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py	2023-10-30 19:32:26.833431+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py	2023-10-30 19:35:28.665002+00:00
@@ -24,11 +24,11 @@
        input = [torch.randn(2, 2)]
        self.run_test(
            TestModule(),
            input,
        )
-    
+
    def test_index_zero_two_dim_ITensor(self):
        class TestModule(nn.Module):
            def forward(self, x, index0):
                indices = [None, index0]
                out = torch.ops.aten.index.Tensor(x, indices)
@@ -56,25 +56,22 @@
        input = [torch.randn(2, 2, 2)]
        self.run_test(
            TestModule(),
            input,
        )
-    
+
    def test_index_zero_index_three_dim_ITensor(self):
        class TestModule(nn.Module):
            def forward(self, x, index0):
                indices = [None, index0, None]
                out = torch.ops.aten.index.Tensor(x, indices)
                return out

        input = torch.randn(2, 2, 2)
        index0 = torch.randint(0, 1, (1, 1))
        index0 = index0.to(torch.int32)
-        self.run_test(
-            TestModule(),
-            [input, index0]
-        )
+        self.run_test(TestModule(), [input, index0])

    def test_index_zero_index_one_index_two_three_dim(self):
        class TestModule(nn.Module):
            def __init__(self):
                self.index0 = torch.randint(0, 1, (1, 1))

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

@@ -2,11 +2,10 @@

import torch
import torch.nn as nn
from harness import DispatchTestCase
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Requires switch back to .harness to pass CI

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok corrected, wil wait for CI then will merge.

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link
Collaborator

@gs-olive gs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me!

@apbose apbose merged commit 7029e91 into main Nov 8, 2023
17 of 19 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants