Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support aten.atan2 converter #2689

Merged
merged 2 commits into from
Apr 12, 2024
Merged

feat: support aten.atan2 converter #2689

merged 2 commits into from
Apr 12, 2024

Conversation

chohk88
Copy link
Collaborator

@chohk88 chohk88 commented Mar 13, 2024

Description

New feature to support aten.atan2 converter. I also add test case including edge case for inputs are zero.

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Mar 13, 2024
@github-actions github-actions bot requested a review from apbose March 13, 2024 13:16
@chohk88 chohk88 self-assigned this Mar 13, 2024
Copy link
Collaborator

@zewenli98 zewenli98 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good to me. It seems that get_value_by_condition is a helper function. Can you move it to .../dynamo/conversion/converter_utils.py?

source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
other: TRTTensor,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this function support other being a Python float or does it require TRTTensor?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it requires TRTTensor because both input and output are Tensor

image

Comment on lines 736 to 747
def get_value_by_condition(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
other: TRTTensor,
condition: TRTTensor,
) -> TRTTensor:
select_layer = ctx.net.add_select(condition, input, other)
set_layer_name(select_layer, target, name + "_select", source_ir)
return select_layer.get_output(0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be removed - replace with usage of impl.condition.select, from here:

def select(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
other: TRTTensor,
condition: TRTTensor,
) -> TRTTensor:
select_layer = ctx.net.add_select(condition, input, other)
set_layer_name(select_layer, target, name + "_select", source_ir)
return select_layer.get_output(0)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, you may need where:

def where(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: Union[TRTTensor, np.ndarray, torch.Tensor],
other: Union[TRTTensor, np.ndarray, torch.Tensor],
condition: Union[TRTTensor, np.ndarray, torch.Tensor],
) -> TRTTensor:

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed get_value_by_condition and replaced it with impl.condition.select.

@@ -1391,6 +1391,24 @@ def aten_ops_atanh(
)


@dynamo_tensorrt_converter(torch.ops.aten.atan2.default)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you require one or both input tensors to be TRTTensor, consider using:

@enforce_tensor_types(
{
0: (TRTTensor,),
}
)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I add enforce_tensor_types for both input and other.

@chohk88 chohk88 linked an issue Mar 25, 2024 that may be closed by this pull request
@chohk88
Copy link
Collaborator Author

chohk88 commented Mar 25, 2024

Overall looks good to me. It seems that get_value_by_condition is a helper function. Can you move it to .../dynamo/conversion/converter_utils.py?

I removed get_value_by_condition and replaced it with impl.condition.select.

@chohk88
Copy link
Collaborator Author

chohk88 commented Mar 25, 2024

@zewenli98 @gs-olive I've made the changes you suggested. I also noticed I missed the case when input or other is a single constant, so I fixed that using broadcast function. And, I added test cases to check it's all working right. Thanks for your helpful suggestions!

Copy link
Collaborator

@narendasan narendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@narendasan narendasan merged commit 821ff91 into main Apr 12, 2024
16 of 21 checks passed
@narendasan narendasan deleted the atan2_converter_update branch April 12, 2024 00:36
peri044 pushed a commit that referenced this pull request Apr 19, 2024
laikhtewari pushed a commit that referenced this pull request May 24, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

aten.atan2
5 participants