
feat: Add handling for ITensor mean and var in batch_norm #3099

Merged: 3 commits merged into main on Aug 22, 2024

Conversation

@chohk88 chohk88 (Collaborator) commented Aug 19, 2024

Description

Support ITensor type running_mean and running_var arguments for Batch Norm converter.
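For illustration, a hedged sketch of the kind of case this covers (the function name bn_with_dynamic_stats is hypothetical, not code from this PR): when the running statistics are graph inputs or are computed inside the traced graph, the Dynamo converter receives them as ITensors rather than frozen torch.Tensor / np.ndarray constants.

    import torch

    # Hypothetical scenario: the running statistics arrive as graph inputs,
    # so after tracing they reach the batch_norm converter as ITensors
    # instead of constant tensors.
    def bn_with_dynamic_stats(x, weight, bias, running_mean, running_var):
        out, _, _ = torch.ops.aten._native_batch_norm_legit_no_training(
            x, weight, bias, running_mean, running_var, 0.1, 1e-5
        )
        return out

    x = torch.randn(2, 3, 8, 8)
    y = bn_with_dynamic_stats(x, torch.ones(3), torch.zeros(3), torch.zeros(3), torch.ones(3))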

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

@chohk88 chohk88 self-assigned this Aug 19, 2024
@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Aug 19, 2024
Comment on lines 37 to 40
  weight: Optional[Union[torch.Tensor, np.ndarray]],
  bias: Optional[Union[torch.Tensor, np.ndarray]],
- running_mean: Optional[Union[torch.Tensor, np.ndarray]],
- running_var: Optional[Union[torch.Tensor, np.ndarray]],
+ running_mean: Union[TRTTensor, Optional[Union[torch.Tensor, np.ndarray]]],
+ running_var: Union[TRTTensor, Optional[Union[torch.Tensor, np.ndarray]]],
Collaborator
Per the schema, these types seem to be:

weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
running_mean: Union[TRTTensor, torch.Tensor, np.ndarray],
running_var: Union[TRTTensor, torch.Tensor, np.ndarray],

Collaborator
I just noticed that torch.ops.aten.native_batch_norm.default and torch.ops.aten.batch_norm.default reuse the function. Since they require running_mean and running_var to be optional, you can set all of these types to Optional[Union[TRTTensor, torch.Tensor, np.ndarray]].

Collaborator Author
Thank you for your comments! I've resolved the issue based on your feedback.

Comment on lines 55 to 64
if isinstance(running_mean, TRTTensor) or isinstance(running_var, TRTTensor):
    # Default values if weight, bias, running_mean, running_var are None
    if weight is None:
        weight = get_trt_tensor(ctx, 1.0, f"{name}_weight", input.dtype)
    if bias is None:
        bias = get_trt_tensor(ctx, 0.0, f"{name}_bias", input.dtype)
    if running_mean is None:
        running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean", input.dtype)
    if running_var is None:
        running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var", input.dtype)
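As a quick sanity check in plain PyTorch (not part of the PR): these defaults (weight=1, bias=0, mean=0, var=1) make the normalization an identity up to the eps term.

    import torch

    x = torch.randn(2, 3, 4, 4)
    out = torch.nn.functional.batch_norm(
        x,
        running_mean=torch.zeros(3),
        running_var=torch.ones(3),
        weight=torch.ones(3),
        bias=torch.zeros(3),
        training=False,
        eps=1e-5,
    )
    # Output differs from x only by the 1/sqrt(1 + eps) scaling.
    assert torch.allclose(out, x, atol=1e-4)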
Collaborator
  1. Do we need to cast these parameters to the type of input? There's probably a case where the input is an int type while weight, bias, running_mean, and/or running_var are float types, and it seems problematic to force-cast float to int. The dtype argument is optional, so you can just leave it blank.

  2. weight and bias could be ITensor as well, right?

@keehyuna keehyuna (Collaborator) commented Aug 20, 2024
>> Do we need to cast these parameters to the type of input?
I think this is required for strongly typed networks; different types are allowed in weakly typed networks. I noticed it by enabling strongly typed networks with some models. Here are some changes to keep the same type in ops. I only saw float32 and half float; I'm not sure whether float and int tensors in the same layer are possible.
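A small plain-PyTorch illustration of the mismatch described above (the cast shown is conceptually what the converter needs to do, not the torch_tensorrt code):

    import torch

    x = torch.randn(2, 3, 8, 8, dtype=torch.half)        # fp16 activations
    running_mean = torch.zeros(3, dtype=torch.float32)   # stats commonly stay fp32
    running_var = torch.ones(3, dtype=torch.float32)

    # Per the comment above, a strongly typed network will not accept mixed
    # fp16/fp32 operands in the same elementwise op, so the statistics (or the
    # input) must be cast to one common dtype before the layers are built.
    running_mean = running_mean.to(x.dtype)
    running_var = running_var.to(x.dtype)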

Collaborator Author
>> Do we need to cast these parameters to the type of input?
If input_val is a TRTTensor, the input is returned unchanged (code), so the type stays the same. If input_val is something else, create_constant converts the type using to_numpy (code). However, if the value is an np.ndarray or torch.Tensor, it keeps the original type (code and code). This means setting input.dtype has no effect.

I confirmed this behavior when I removed the input.dtype argument.

As for the issue @keehyuna mentioned, it's new to me, so I'll investigate.

>> weight and bias could be ITensor as well, right?
Additionally, according to the schema, weight and bias can't be ITensor, but the converter works fine. I've tested it, and it works successfully.
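A toy sketch of the dtype behavior described above (make_constant is a stand-in for illustration, not the actual get_trt_tensor / create_constant implementation):

    import numpy as np
    import torch

    def make_constant(value, dtype=None):
        # Mirrors the described behavior: np.ndarray / torch.Tensor values keep
        # their original dtype; a dtype hint only takes effect for Python scalars.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        if isinstance(value, np.ndarray):
            return value
        return np.array(value, dtype=dtype if dtype is not None else np.float32)

    print(make_constant(np.zeros(3, dtype=np.float32), dtype=np.float16).dtype)  # float32
    print(make_constant(0.0, dtype=np.float16).dtype)                            # float16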

Collaborator
>> weight and bias could be ITensor as well, right?
>> Additionally, according to the schema, weight and bias can't be ITensor, but the converter works fine. I've tested it, and it works successfully.

@chohk88 Did I miss something? The schema is:

- func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
- func: _native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)

My understanding of Tensor? weight, Tensor? bias is that they could be None, ITensor, torch.Tensor, or np.ndarray. Please correct me if I'm wrong.

Collaborator Author
Oh, I misunderstood that. You're right: weight, bias, running_mean, and running_var can all be TRTTensor. I've combined the separate converters and added some test cases for this.
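For context, a minimal PyTorch reference of the computation the combined converter has to reproduce with elementwise TensorRT layers when the statistics are ITensors (this shows the math only, not the converter code):

    import torch

    def batch_norm_ref(x, weight, bias, running_mean, running_var, eps=1e-5):
        # Broadcast per-channel (C,) statistics over an (N, C, ...) input and
        # normalize elementwise: (x - mean) / sqrt(var + eps) * weight + bias
        shape = [1, -1] + [1] * (x.dim() - 2)
        mean = running_mean.reshape(shape)
        var = running_var.reshape(shape)
        scale = weight.reshape(shape) if weight is not None else 1.0
        shift = bias.reshape(shape) if bias is not None else 0.0
        return (x - mean) / torch.sqrt(var + eps) * scale + shift

    x, w, b = torch.randn(2, 3, 8, 8), torch.rand(3), torch.rand(3)
    ref = batch_norm_ref(x, w, b, torch.zeros(3), torch.ones(3))
    expected = torch.nn.functional.batch_norm(x, torch.zeros(3), torch.ones(3), w, b)
    torch.testing.assert_close(ref, expected, rtol=1e-4, atol=1e-5)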

@chohk88 chohk88 force-pushed the converter_batch_norm_with_TRTTensor_mean_var branch from e1773b4 to 9133df8 on August 20, 2024 03:42
@peri044 peri044 (Collaborator) left a comment
LGTM

@chohk88 chohk88 (Collaborator Author) commented Aug 21, 2024
>> LGTM

If the CI results show no more batch_norm errors, it's good to merge.

@chohk88 chohk88 (Collaborator Author) commented Aug 22, 2024

@zewenli98 Although it's beyond the scope of this PR, I noticed something regarding the converters for layer_norm, group_norm, and native_group_norm.

Currently, weight and bias are defined as:

weight: Optional[Union[torch.Tensor, np.ndarray]],
bias: Optional[Union[torch.Tensor, np.ndarray]],

Is this incorrect? Fortunately, unlike batch norm, there's no issue with adding eps to, or applying to_numpy on, a TRTTensor (or ITensor) weight or bias, so it doesn't seem to cause any problems with the converter. However, it seems the case where the default value for weight in layer_norm is None might not be handled properly.

Here are the schemas:
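(For reference, the schemas in question are roughly the following, reconstructed from memory of native_functions.yaml; the exact signatures may differ slightly, but each declares weight and bias as Tensor?.)

- func: layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor
- func: group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor
- func: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)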

@peri044 peri044 (Collaborator) commented Aug 22, 2024

Merging this as the other failures are unrelated.

@peri044 peri044 merged commit 66511da into main Aug 22, 2024
49 of 67 checks passed
@zewenli98 (Collaborator) commented:
@chohk88 Thanks for pointing out the issue. I think your understanding is correct. Although nothing errors out now, they should be Optional[Union[TRTTensor, torch.Tensor, np.ndarray]]. Since Issue #3114 has been opened above to track it, could you modify them to Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], like what you did for batch_norm?
