❓ [Question] Running LayerNorm in fp16 #2730
Comments
Here is a minimal reproducible example:

```python
import torch
import torch.nn as nn


class LayerNormFP32(nn.LayerNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


class Model(nn.Module):
    def __init__(self, hidden_dim: int = 1024):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.ln = LayerNormFP32(hidden_dim, bias=False)

    def forward(self, x: torch.Tensor):
        return self.ln(x)

    def to_jit_trace(
        self,
        device: str = "cpu",
        dtype: torch.dtype = torch.float,
        batch_size: int = 2,
    ) -> torch.jit.ScriptModule:
        dummy_inputs = torch.randn((batch_size, self.hidden_dim), dtype=dtype, device=device)
        self.to(device)
        self.eval()
        with torch.no_grad():
            outputs1 = self(*dummy_inputs)
            trace = torch.jit.trace(self, dummy_inputs, check_trace=False)
            outputs2 = trace(*dummy_inputs)
            assert torch.allclose(outputs1, outputs2)
        return trace, dummy_inputs

    def to_tensorrt(
        self,
        batch_size,
        precisions: set[torch.dtype] = {
            torch.float,
            torch.half,
        },
    ):
        import torch_tensorrt

        dtype = torch.float
        if torch.half in precisions:
            dtype = torch.half
        with torch.cuda.amp.autocast(enabled=True):
            trace, dummy_inputs = self.to_jit_trace("cuda", dtype, batch_size=batch_size)
            trt = torch_tensorrt.compile(
                trace,
                input_signature=(torch_tensorrt.Input(shape=dummy_inputs.shape, dtype=dummy_inputs.dtype),),
                enabled_precisions=precisions,
                require_full_compilation=True,
                truncate_long_and_double=True,
            )
        return trt
```

fp32 gives the same outputs, fp16 does not (while producing the warnings):

```python
model = Model()
batch_size = 1

trt_16 = model.to_tensorrt(batch_size=batch_size, precisions={torch.float, torch.half})
with torch.cuda.amp.autocast(enabled=True):
    trace_fp16, dummy_inputs_16 = model.to_jit_trace("cuda", torch.half, batch_size=batch_size)

trt_32 = model.to_tensorrt(batch_size=batch_size, precisions={torch.float})
trace_fp32, dummy_inputs_32 = model.to_jit_trace("cuda", torch.float, batch_size=batch_size)

with torch.no_grad():
    # False
    # tensor(0.0020, device='cuda:0', dtype=torch.float16)
    print(torch.allclose(trace_fp16(dummy_inputs_16), trt_16(dummy_inputs_16)))
    print((trace_fp16(dummy_inputs_16) - trt_16(dummy_inputs_16)).abs().max())
    # True
    # tensor(2.9802e-08, device='cuda:0')
    print(torch.allclose(trace_fp32(dummy_inputs_32), trt_32(dummy_inputs_32)))
    print((trace_fp32(dummy_inputs_32) - trt_32(dummy_inputs_32)).abs().max())
```
Hi @Tomiinek, I refactored the layer norm to use the INormalization layer. Could you confirm whether this works for you? Thanks!
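For context, a minimal standalone sketch of how a layer norm can be expressed with TensorRT's INormalizationLayer while keeping its statistics in FP32. This is not the converter code from the PR; the network layout, tensor names, and the choice to pin `compute_precision` are illustrative assumptions (TensorRT >= 8.6 assumed, where INormalizationLayer was introduced):

```python
# Hypothetical standalone sketch, not the Torch-TensorRT converter itself.
import numpy as np
import tensorrt as trt

HIDDEN = 1024
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

x = network.add_input("x", trt.float32, (2, HIDDEN))
scale = network.add_constant((1, HIDDEN), trt.Weights(np.ones((1, HIDDEN), dtype=np.float32)))
bias = network.add_constant((1, HIDDEN), trt.Weights(np.zeros((1, HIDDEN), dtype=np.float32)))

# Normalize over the last axis: bit i of the axes mask selects dimension i.
ln = network.add_normalization(x, scale.get_output(0), bias.get_output(0), 1 << 1)
ln.epsilon = 1e-5
# Keep the mean/variance computation in FP32 even when the engine runs in FP16.
ln.compute_precision = trt.float32
network.mark_output(ln.get_output(0))

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)
engine = builder.build_serialized_network(network, config)
```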
Hello @zewenli98, thank you! I am having issues compiling the latest code in my environment (Python 3.11, torch 2.2), so I tried to use the wheel from the GitHub Actions run associated with the PR (this one: https://github.com/pytorch/TensorRT/actions/runs/8711801688/artifacts/1419799870), but also without success. Simply patching the file in site-packages of the latest release did not help (i.e. the fp16 issue persists). Is there another way to check it out, or to catch it in tests?
@Tomiinek It seems the
Hi @zewenli98, thank you for your patience. I tried something like:
but it says
because I am still on 2.2.0. So I tried to upgrade to 2.3.0dev, but I am not able to import the package:
Do you have any tips on how to install or try out the latest and greatest code or builds? These are my versions:
Hi @Tomiinek, for this error:
This is because you might have installed mismatched
and then build torch-tensorrt again with:
Besides, you can try to use:
Hi @zewenli98, thanks for your responses! I'm trying to create a wheel for @Tomiinek to test out the fix. I'm opting for Docker, as local compilation gave me some weird errors about incompatible hashes when downloading tarballs from Nvidia. I've changed the libtorch sections per your suggestion, checked out your PR branch, and ran
Perhaps it would be easier to merge the PR and we'll test whether the nightly wheel of TensorRT works? Compiling Torch-TensorRT locally seems to be pretty complicated.
Hello @zewenli98, I installed the current release with Python 3.10 so that I could at least try out dynamo.

I tried to compile a single linear layer with the torchscript frontend in fp32: the compiled module gives correct outputs (i.e. the same as the raw module), but not in fp16, which I believe changed from the previous release, which gave correct outputs but ignored the casting in layer norms. I then tried to compile a single linear layer with dynamo in fp32: I am not getting correct outputs, and the compiled module is 3x slower than the one compiled with the torchscript frontend. The layer norm issue persists with torchscript, and dynamo does not produce warnings but still produces weird outputs.

I am really confused. Could you please help me and provide code snippets that I can run and that at the same time work for you? Specifically:
Or at least tell me whether the code I posted above works for you with the latest release, or what I am doing wrong in there 🤷 CC: @narendasan
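For illustration, a minimal sketch of the kind of snippet being asked for here: compiling a single linear layer with the dynamo frontend in fp16 and comparing against the eager outputs. The module, shapes, and tolerances below are assumptions for the sketch, not code from this thread:

```python
import torch
import torch_tensorrt

# Toy module: a single linear layer, run in fp16 on the GPU.
model = torch.nn.Linear(1024, 1024).cuda().half().eval()
inputs = [torch.randn(2, 1024, device="cuda", dtype=torch.half)]

# Compile with the dynamo frontend, allowing fp16 kernels.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.half},
)

with torch.no_grad():
    ref = model(*inputs)
    out = trt_model(*inputs)
    print(torch.allclose(ref, out, atol=1e-2, rtol=1e-2))
    print((ref - out).abs().max())
```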
@narendasan @peri044 Can you guys take a look?
❓ Question
What you have already tried
I am trying to convert a transformer model to TRT in fp16 (fp32 works fine 🙂). It includes a bunch of LayerNorms, all of which have explicit casting of inputs to fp32, i.e.:
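A minimal sketch of that casting pattern, matching the LayerNormFP32 wrapper from the reproduction earlier in the thread:

```python
import torch.nn as nn

class LayerNormFP32(nn.LayerNorm):
    # Run the normalization in fp32, then cast the result back to the input dtype.
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)
```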
I am getting warnings about precisions of the layers:
I checked the dtype of the mentioned weights in the trace that I pass to torch_tensorrt.compile, and they are correctly in fp32, even though the warnings state the opposite. The warning suggests two solutions (use INormalizationLayer or force FP32 precision), but I have no idea how to achieve either.
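A quick, illustrative way to do that dtype check (assuming the `trace` object from the reproduction above, not code from the original report):

```python
# Print the dtype of every parameter in the trace handed to torch_tensorrt.compile;
# the LayerNorm weight reports torch.float32 here despite the warning.
for name, param in trace.named_parameters():
    print(name, param.dtype)
```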
This might be related: #2509 (or NVIDIA/TensorRT#3101)
Any ideas how to resolve or debug this issue?
Environment