-
Notifications
You must be signed in to change notification settings - Fork 350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support aten.atan2 converter #2689
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good to me. It seems that get_value_by_condition
is a helper function. Can you move it to .../dynamo/conversion/converter_utils.py
?
source_ir: Optional[SourceIR], | ||
name: str, | ||
input: TRTTensor, | ||
other: TRTTensor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this function support other
being a Python float or does it require TRTTensor
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def get_value_by_condition( | ||
ctx: ConversionContext, | ||
target: Target, | ||
source_ir: Optional[SourceIR], | ||
name: str, | ||
input: TRTTensor, | ||
other: TRTTensor, | ||
condition: TRTTensor, | ||
) -> TRTTensor: | ||
select_layer = ctx.net.add_select(condition, input, other) | ||
set_layer_name(select_layer, target, name + "_select", source_ir) | ||
return select_layer.get_output(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be removed - replace with usage of impl.condition.select
, from here:
TensorRT/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
Lines 102 to 113 in 76c9ebd
def select( | |
ctx: ConversionContext, | |
target: Target, | |
source_ir: Optional[SourceIR], | |
name: str, | |
input: TRTTensor, | |
other: TRTTensor, | |
condition: TRTTensor, | |
) -> TRTTensor: | |
select_layer = ctx.net.add_select(condition, input, other) | |
set_layer_name(select_layer, target, name + "_select", source_ir) | |
return select_layer.get_output(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alternatively, you may need where
:
TensorRT/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
Lines 18 to 26 in 76c9ebd
def where( | |
ctx: ConversionContext, | |
target: Target, | |
source_ir: Optional[SourceIR], | |
name: str, | |
input: Union[TRTTensor, np.ndarray, torch.Tensor], | |
other: Union[TRTTensor, np.ndarray, torch.Tensor], | |
condition: Union[TRTTensor, np.ndarray, torch.Tensor], | |
) -> TRTTensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed get_value_by_condition
and replaced it with impl.condition.select.
@@ -1391,6 +1391,24 @@ def aten_ops_atanh( | |||
) | |||
|
|||
|
|||
@dynamo_tensorrt_converter(torch.ops.aten.atan2.default) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you require one or both input tensors to be TRTTensor
, consider using:
TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Lines 62 to 66 in 6b8ccbc
@enforce_tensor_types( | |
{ | |
0: (TRTTensor,), | |
} | |
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I add enforce_tensor_types
for both input and other.
6b8ccbc
to
414b806
Compare
I removed |
@zewenli98 @gs-olive I've made the changes you suggested. I also noticed I missed the case when input or other is a single constant, so I fixed that using |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Description
New feature to support aten.atan2 converter. I also add test case including edge case for inputs are zero.
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: