
feat: dynamic shape support for aten.select.int #2990

Merged 3 commits into main on Jul 31, 2024

Conversation


@chohk88 (Collaborator) commented Jul 9, 2024

Description

Adds dynamic shape support for aten.select.int. As the schema below shows, the dim and index arguments of select.int are scalar ints (int / SymInt), not lists, so no reshaping is required.

- func: select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a)
  variants: function, method
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeExplicitAutograd: select_symint
    SparseCsrCPU, SparseCsrCUDA: select_sparse_csr
    NestedTensorCPU, NestedTensorCUDA: select_nested
  tags: core
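
Not the PR's actual code, but a minimal sketch of the shape such a converter might take, assuming TensorRT's add_gather API and the helpers (get_positive_dim, get_trt_tensor) that come up in the review below:

    import numpy as np

    # Minimal sketch only; helper signatures are assumptions based on
    # the review discussion, not the PR's verbatim implementation.
    def select(ctx, target, source_ir, name, input, dim, index):
        ranks = len(input.shape)
        # Normalize a possibly negative dim into [0, ranks).
        dim = get_positive_dim(dim, ranks)
        # Per the schema, index is always a scalar int, so a 0-d int32
        # constant suffices; the index itself can never be dynamic.
        indices_tensor = get_trt_tensor(
            ctx, np.array(index, dtype=np.int32), f"{name}_index"
        )
        # Gathering with a 0-d index along `dim` drops that dimension,
        # matching aten.select.int semantics.
        layer = ctx.net.add_gather(input, indices_tensor, axis=dim)
        return layer.get_output(0)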

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@chohk88 self-assigned this Jul 9, 2024
@chohk88 requested a review from @peri044 on July 9, 2024
@github-actions bot added labels on Jul 9, 2024: component: tests, component: conversion, component: converters, component: api [Python], component: dynamo

@peri044 (Collaborator) left a comment:


Mark torch.ops.aten.select.int with supports_dynamic_shapes=True

output_shape = get_shape_with_dynamic_shape(
    ctx, target, source_ir, name, output_shape, input
)

index_value = np.array(index, dtype=np.int32)
Collaborator:

Can index itself be dynamic (ITensor)?

@chohk88 (Author):

Based on the PyTorch docs and the schema, index is always a single integer and cannot be a list or tuple of integers. The indices_tensor created from index (via index_value) is therefore always a scalar, so it cannot be dynamic.

@chohk88 (Author) commented Jul 10, 2024

Mark torch.ops.aten.select.int with supports_dynamic_shapes=True

I have marked torch.ops.aten.select.int with supports_dynamic_shapes=True.
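
For reference, a hedged sketch of what that registration might look like; the decorator and wrapper layout follow common Torch-TensorRT converter conventions and are not the PR's verbatim code:

    # Hypothetical registration sketch; names and argument layout are
    # assumptions modeled on Torch-TensorRT converter conventions.
    @dynamo_tensorrt_converter(
        torch.ops.aten.select.int, supports_dynamic_shapes=True
    )
    def aten_ops_select(ctx, target, args, kwargs, name):
        return impl.select.select(
            ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
        )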

output_shape = get_shape_with_dynamic_shape(
    ctx, target, source_ir, name, output_shape, input
)

@keehyuna (Collaborator):

If the above asserts are removed and a fully dynamic shape is used (e.g., (-1, -1, -1)), the test works. I'm wondering whether select on a dynamic dim can be supported.

@chohk88 (Author):

I ran the test you mentioned and confirmed that (-1, -1, -1) also passes.

However, I didn't remove the asserts, because the test cases fail when the index is out of range, i.e., when the index is larger than the corresponding dimension of the dynamic input shape. In the parameter tuples below (presumably: test name, min/opt/max shapes, dtype, dim, index), the 'fail case' selects index 3 on a dimension whose maximum size is 3.

            (
                "success case",
                (1, 1, 1),
                (2, 2, 2),
                (3, 3, 3),
                torch.float,
                0,
                1,
            ),
            (
                "fail case",
                (1, 1, 1),
                (2, 2, 2),
                (3, 3, 3),
                torch.float,
                0,
                3,
            ),

It seems that normalizing the index the same way dim is normalized (dim = get_positive_dim(cast(int, dim), ranks)) could solve the issue when the index is greater than the dimension size. Do you have an example that handles it this way?

@keehyuna (Collaborator):

I think we cannot check for an invalid index on a dynamic dim; the error will happen at runtime. Maybe we can check the index only for static dims:

if DYNAMIC_DIM != input.shape[dim] and index >= input.shape[dim]:
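
Spelled out, that guard might look like the sketch below, assuming DYNAMIC_DIM is the sentinel that marks a dynamic dimension in the build-time shape:

    # Sketch of the suggested static-dim-only bounds check; DYNAMIC_DIM
    # is assumed to be the sentinel (e.g., -1) for dynamic dimensions.
    if DYNAMIC_DIM != input.shape[dim] and index >= input.shape[dim]:
        raise RuntimeError(
            f"index {index} is out of bounds for dimension {dim} "
            f"with size {input.shape[dim]}"
        )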

@chohk88 (Author):

I also tested with (-1, -1, -1) after removing the assert() and raise RuntimeError() statements.

The 'success case' above passes, but the 'fail case' fails: its 0-th dimension has size 3, while the index selects a position beyond that (index=3, i.e., the 4th position). The current converter does not handle the case where the index is larger than the size of the dimension specified by dim. (Note: it might be possible to handle this with a slice layer plus lt (less than) and div operations, but for now it is handled with an assert statement.)

Therefore, I left the assert() and raise RuntimeError() statements in place.

@chohk88 (Author) commented Jul 16, 2024:

Oh! I misunderstood. If the index is larger than the size of the dimension specified by dim, it won't work in PyTorch either, so we don't need to handle this case, and we don't need test cases like the 'fail case' above. This means dynamic shape is supported for all dimensions.

PyTorch example of correct usage of select.int: [screenshot not preserved]

PyTorch example of select.int raising an IndexError: [screenshot not preserved]
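
A minimal reconstruction of what the screenshots likely showed:

    import torch

    x = torch.randn(3, 3, 3)

    # Correct usage: index 2 is within dim 0 (size 3, valid indices 0..2).
    y = torch.select(x, dim=0, index=2)  # equivalent to x[2]; shape (3, 3)

    # Error case: index 3 is out of range for dim 0 of size 3.
    # PyTorch raises an IndexError along the lines of:
    #   IndexError: select(): index 3 out of range for tensor of size
    #   [3, 3, 3] at dimension 0
    torch.select(x, dim=0, index=3)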

@chohk88 (Author):

@keehyuna Thanks for the suggestion!

Since PyTorch already raises an error when the index exceeds the input size, there’s no need for us to check this here. I removed those checks. Additionally, following your suggestion, I’ve added test cases to fully support dynamic shapes.

@peri044 (Collaborator) left a comment:

Do we need this casting here since it will always be int?

dim = get_positive_dim(cast(int, dim), ranks)

We could also change the type of dim from Shape to int in the function signature.

indices_tensor = ctx.net.add_constant(
    index_value.shape, to_numpy(index_value)
).get_output(0)
indices_tensor = ctx.net.add_constant(index_value.shape, index_value).get_output(0)
Collaborator:

Can you use a get_trt_tensor call here?

@chohk88 (Author):

Thanks for the suggestion!

dim and index are now int (not Shape), and indices_tensor is now created with get_trt_tensor.
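
For illustration, the revised line might look roughly like this (the exact get_trt_tensor signature is an assumption):

    # Sketch: the scalar index is materialized via get_trt_tensor rather
    # than a raw ctx.net.add_constant call; signature is an assumption.
    indices_tensor = get_trt_tensor(
        ctx, np.array(index, dtype=np.int32), f"{name}_indices_tensor"
    )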

@chohk88 requested review from @keehyuna and @peri044 on July 24, 2024
@peri044 (Collaborator) left a comment:

LGTM

@peri044 merged commit c99c966 into main on Jul 31, 2024
40 of 65 checks passed
Labels: cla signed, component: api [Python], component: conversion, component: converters, component: dynamo, component: tests