FP16 outputs error of TensorRT 8.6.1.2 when running Roberta #3101

Open
DayDayupupupup opened this issue Jun 30, 2023 · 10 comments
Labels: Accuracy, Precision: FP16, triaged (Issue has been triaged by maintainers)

@DayDayupupupup

Description

Since the INormalization layer was added in TRT 8.6, I ran some tests on FP16 accuracy:

  1. First, I exported Hugging Face's bert-base-cased to ONNX (opset 17) and used Polygraphy to test FP16 accuracy. Both outputs (last_hidden_state, pooler_output) passed: Difference is within tolerance (rel=1e-05, abs=0.01).
  2. Then I tested roberta-base and found that the FP16 results still had errors: PASSED | Output: 'pooler_output' | Difference is within tolerance (rel=1e-05, abs=0.01), FAILED | Output: 'last_hidden_state'.

Environment

TensorRT Version: 8.6.1.2
NVIDIA GPU: A30
NVIDIA Driver Version: 510.47.03
CUDA Version: 11.6
Operating System: Ubuntu 20.04.2 LTS
Tensorflow Version (if applicable): 1.15.5
Container version: nvcr.io/nvidia/tensorrt:23.05-py3

Steps To Reproduce

Test 1: roberta-base

  1. Export roberta-base to ONNX:

```python
import torch
from transformers import RobertaTokenizer, RobertaModel

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base')
model.eval()  # disable dropout before running/tracing the model

text = "Replace me by any text you'd like."
# RobertaTokenizer returns input_ids and attention_mask (no token_type_ids)
encoded_input = tokenizer(text, padding='max_length', max_length=128, return_tensors='pt')

with torch.no_grad():
    output = model(**encoded_input)  # PyTorch reference output
    torch.onnx.export(model,
                      tuple(encoded_input.values()),  # (input_ids, attention_mask)
                      "roberta_base_opset17.onnx",
                      export_params=True,
                      opset_version=17,
                      do_constant_folding=True,
                      input_names=['input_ids', 'input_mask'],
                      output_names=['last_hidden_state', 'pooler_output'],
                      dynamic_axes={'input_ids': {0: 'batch_size'},
                                    'input_mask': {0: 'batch_size'},
                                    'last_hidden_state': {0: 'batch_size'},
                                    'pooler_output': {0: 'batch_size'}})
```

  2. Compare TensorRT (FP16) against ONNX Runtime with Polygraphy:

```
polygraphy run roberta_base_opset17.onnx --trt --onnxrt --atol 0.01 --pool-limit workspace:10G --fp16
```

[I]     Comparing Output: 'last_hidden_state' (dtype=float32, shape=(1, 128, 768)) with 'last_hidden_state' (dtype=float32, shape=(1, 128, 768))
[I]         Tolerance: [abs=0.01, rel=1e-05] | Checking elemwise error
[I]         trt-runner-N0-06/30/23-08:09:48: last_hidden_state | Stats: mean=0.020138, std-dev=0.4103, var=0.16835, median=0.0063438, min=-2.6055 at (0, 0, 453), max=11.375 at (0, 9, 588), avg-magnitude=0.11272
[I]             ---- Histogram ----
                Bin Range      |  Num Elems | Visualization
                (-2.61, -1.21) |        213 |
                (-1.21, 0.191) |      92069 | ########################################
                (0.191, 1.59 ) |       5752 | ##
                (1.59 , 2.99 ) |         71 |
                (2.99 , 4.38 ) |          0 |
                (4.38 , 5.78 ) |         71 |
                (5.78 , 7.18 ) |          0 |
                (7.18 , 8.58 ) |         71 |
                (8.58 , 9.98 ) |          0 |
                (9.98 , 11.4 ) |         57 |
[I]         onnxrt-runner-N0-06/30/23-08:09:48: last_hidden_state | Stats: mean=0.020128, std-dev=0.40961, var=0.16778, median=0.0070637, min=-2.5995 at (0, 0, 453), max=11.349 at (0, 38, 588), avg-magnitude=0.11256
[I]             ---- Histogram ----
                Bin Range      |  Num Elems | Visualization
                (-2.61, -1.21) |        213 |
                (-1.21, 0.191) |      92069 | ########################################
                (0.191, 1.59 ) |       5752 | ##
                (1.59 , 2.99 ) |         71 |
                (2.99 , 4.38 ) |          0 |
                (4.38 , 5.78 ) |         71 |
                (5.78 , 7.18 ) |          0 |
                (7.18 , 8.58 ) |         71 |
                (8.58 , 9.98 ) |          0 |
                (9.98 , 11.4 ) |         57 |
[I]         Error Metrics: last_hidden_state
[I]             Minimum Required Tolerance: elemwise error | [abs=0.049532] OR [rel=2643.4] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0.0011686, std-dev=0.0015184, var=2.3056e-06, median=0.00079408, min=1.4901e-08 at (0, 105, 34), max=0.049532 at (0, 69, 588), avg-magnitude=0.0011686
[I]                 ---- Histogram ----
                    Bin Range           |  Num Elems | Visualization
                    (1.49e-08, 0.00495) |      97111 | ########################################
                    (0.00495 , 0.00991) |        923 |
                    (0.00991 , 0.0149 ) |        142 |
                    (0.0149  , 0.0198 ) |         73 |
                    (0.0198  , 0.0248 ) |          2 |
                    (0.0248  , 0.0297 ) |         10 |
                    (0.0297  , 0.0347 ) |          7 |
                    (0.0347  , 0.0396 ) |         10 |
                    (0.0396  , 0.0446 ) |         17 |
                    (0.0446  , 0.0495 ) |          9 |
[I]             Relative Difference | Stats: mean=0.083636, std-dev=8.4975, var=72.207, median=0.01208, min=9.7178e-07 at (0, 20, 249), max=2643.4 at (0, 69, 485), avg-magnitude=0.083636
[I]                 ---- Histogram ----
                    Bin Range            |  Num Elems | Visualization
                    (9.72e-07, 264     ) |      98303 | ########################################
                    (264     , 529     ) |          0 |
                    (529     , 793     ) |          0 |
                    (793     , 1.06e+03) |          0 |
                    (1.06e+03, 1.32e+03) |          0 |
                    (1.32e+03, 1.59e+03) |          0 |
                    (1.59e+03, 1.85e+03) |          0 |
                    (1.85e+03, 2.11e+03) |          0 |
                    (2.11e+03, 2.38e+03) |          0 |
                    (2.38e+03, 2.64e+03) |          1 |
[E]         FAILED | Output: 'last_hidden_state' | Difference exceeds tolerance (rel=1e-05, abs=0.01)
[I]     Comparing Output: 'pooler_output' (dtype=float32, shape=(1, 768)) with 'pooler_output' (dtype=float32, shape=(1, 768))
[I]         Tolerance: [abs=0.01, rel=1e-05] | Checking elemwise error
[I]         trt-runner-N0-06/30/23-08:09:48: pooler_output | Stats: mean=0.0042347, std-dev=0.21781, var=0.047442, median=0.01236, min=-0.64404 at (0, 165), max=0.58496 at (0, 509), avg-magnitude=0.17412
[I]         onnxrt-runner-N0-06/30/23-08:09:48: pooler_output | Stats: mean=0.0041949, std-dev=0.2177, var=0.047392, median=0.01219, min=-0.64402 at (0, 165), max=0.58522 at (0, 509), avg-magnitude=0.17403
[I]         Error Metrics: pooler_output
[I]             Minimum Required Tolerance: elemwise error | [abs=0.0042159] OR [rel=4.1577] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0.00095319, std-dev=0.00070167, var=4.9234e-07, median=0.00081642, min=1.4156e-06 at (0, 245), max=0.0042159 at (0, 591), avg-magnitude=0.00095319
[I]             Relative Difference | Stats: mean=0.033797, std-dev=0.22256, var=0.049531, median=0.0056062, min=6.7266e-06 at (0, 245), max=4.1577 at (0, 167), avg-magnitude=0.033797
[I]         PASSED | Output: 'pooler_output' | Difference is within tolerance (rel=1e-05, abs=0.01)
[E]     FAILED | Mismatched outputs: ['last_hidden_state']
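
To read the Error Metrics above: my reading of Polygraphy's elementwise check (an assumption based on the "[abs=...] OR [rel=...]" wording in the log, not on Polygraphy's source) is that an element fails only when it exceeds both tolerances, and that the "Minimum Required Tolerance" is simply the largest absolute or relative difference. The hypothetical helper `compare_elemwise` below reproduces that reading with NumPy. Relative error is assumed to be measured against the reference output, which would explain why a near-zero reference value produces a relative error like 2643.4 while the absolute error stays below 0.05.

```python
import numpy as np

def compare_elemwise(out, ref, atol=0.01, rtol=1e-5):
    """Hypothetical re-implementation of the elementwise check as read from the log.

    An element fails only if it exceeds BOTH tolerances; the minimum required
    tolerances are the largest absolute / relative differences.
    """
    out = np.asarray(out, dtype=np.float64)
    ref = np.asarray(ref, dtype=np.float64)
    absdiff = np.abs(out - ref)
    with np.errstate(divide="ignore", invalid="ignore"):
        reldiff = absdiff / np.abs(ref)  # assumed: relative to the reference output
    passed = (absdiff <= atol) | (reldiff <= rtol)
    return {
        "passed": bool(passed.all()),
        "num_failed": int((~passed).sum()),
        "min_required_atol": float(absdiff.max()),
        "min_required_rtol": float(np.nanmax(reldiff)),
    }

# A near-zero reference value gives a huge relative error with a tiny absolute one
result = compare_elemwise([0.0003, 1.002, 11.375], [0.0001, 1.0, 11.349])
print(result)
```

With the default tolerances only the third element fails: its absolute error (0.026) exceeds atol=0.01 and its relative error (about 0.0023) exceeds rtol=1e-5. The first element passes on atol even though its relative error is 2.0.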

When I use real data (the tokenized sentence from the export script), the error is even greater. I saved the inputs with Polygraphy's `save_json()` helper:

```python
import numpy as np
from polygraphy.json import save_json
from transformers import RobertaTokenizer

# Same tokenized sentence as in the export script, as int64 NumPy arrays
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
encoded_input = tokenizer("Replace me by any text you'd like.",
                          padding='max_length', max_length=128, return_tensors='np')

# Option 1: Define a function that will yield feed_dicts (i.e. Dict[str, np.ndarray])
def load_data():
    for _ in range(1):
        yield {"input_ids": encoded_input['input_ids'],
               "input_mask": encoded_input['attention_mask']}  # real tokenized data

# Option 2: Create a JSON file containing the input data using the `save_json()` helper.
#   The input to `save_json()` should have type: List[Dict[str, np.ndarray]].
#   For convenience, we'll reuse our `load_data()` implementation to generate the list.
input_data = list(load_data())
save_json(input_data, "custom_inputs.json", description="custom input data")
```

Then run:

```
polygraphy run roberta_base_opset17.onnx --trt --onnxrt --atol 0.01 --pool-limit workspace:10G --fp16 --load-inputs custom_inputs.json
```

[I]     Comparing Output: 'last_hidden_state' (dtype=float32, shape=(1, 128, 768)) with 'last_hidden_state' (dtype=float32, shape=(1, 128, 768))
[I]         Tolerance: [abs=0.01, rel=1e-05] | Checking elemwise error
[I]         trt-runner-N0-06/30/23-08:20:22: last_hidden_state | Stats: mean=0.018884, std-dev=0.41145, var=0.16929, median=0.0093536, min=-8.2969 at (0, 9, 77), max=12.07 at (0, 10, 588), avg-magnitude=0.11438
[I]             ---- Histogram ----
                Bin Range        |  Num Elems | Visualization
                (-8.3  , -6.26 ) |          5 |
                (-6.26 , -4.22 ) |          4 |
                (-4.22 , -2.18 ) |        117 |
                (-2.18 , -0.148) |       8235 | ###
                (-0.148, 1.89  ) |      89815 | ########################################
                (1.89  , 3.93  ) |          0 |
                (3.93  , 5.96  ) |          0 |
                (5.96  , 8     ) |          0 |
                (8     , 10    ) |          6 |
                (10    , 12.1  ) |        122 |
[I]         onnxrt-runner-N0-06/30/23-08:20:22: last_hidden_state | Stats: mean=0.018878, std-dev=0.41122, var=0.1691, median=0.0091678, min=-8.2829 at (0, 9, 77), max=12.076 at (0, 10, 588), avg-magnitude=0.11435
[I]             ---- Histogram ----
                Bin Range        |  Num Elems | Visualization
                (-8.3  , -6.26 ) |          5 |
                (-6.26 , -4.22 ) |          4 |
                (-4.22 , -2.18 ) |        117 |
                (-2.18 , -0.148) |       8235 | ###
                (-0.148, 1.89  ) |      89815 | ########################################
                (1.89  , 3.93  ) |          0 |
                (3.93  , 5.96  ) |          0 |
                (5.96  , 8     ) |          0 |
                (8     , 10    ) |          6 |
                (10    , 12.1  ) |        122 |
[I]         Error Metrics: last_hidden_state
[I]             Minimum Required Tolerance: elemwise error | [abs=0.046174] OR [rel=62.955] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0.00074002, std-dev=0.00098861, var=9.7735e-07, median=0.0005722, min=1.1176e-08 at (0, 10, 666), max=0.046174 at (0, 7, 77), avg-magnitude=0.00074002
[I]                 ---- Histogram ----
                    Bin Range           |  Num Elems | Visualization
                    (1.12e-08, 0.00462) |      97874 | ########################################
                    (0.00462 , 0.00923) |        178 |
                    (0.00923 , 0.0139 ) |        124 |
                    (0.0139  , 0.0185 ) |        119 |
                    (0.0185  , 0.0231 ) |          4 |
                    (0.0231  , 0.0277 ) |          3 |
                    (0.0277  , 0.0323 ) |          0 |
                    (0.0323  , 0.0369 ) |          0 |
                    (0.0369  , 0.0416 ) |          1 |
                    (0.0416  , 0.0462 ) |          1 |
[I]             Relative Difference | Stats: mean=0.15335, std-dev=2.3405, var=5.4779, median=0.0082764, min=2.8881e-07 at (0, 10, 666), max=62.955 at (0, 12, 85), avg-magnitude=0.15335
[I]                 ---- Histogram ----
                    Bin Range        |  Num Elems | Visualization
                    (2.89e-07, 6.3 ) |      97836 | ########################################
                    (6.3     , 12.6) |        234 |
                    (12.6    , 18.9) |          1 |
                    (18.9    , 25.2) |        117 |
                    (25.2    , 31.5) |          0 |
                    (31.5    , 37.8) |          0 |
                    (37.8    , 44.1) |          0 |
                    (44.1    , 50.4) |          0 |
                    (50.4    , 56.7) |          0 |
                    (56.7    , 63  ) |        116 |
[E]         FAILED | Output: 'last_hidden_state' | Difference exceeds tolerance (rel=1e-05, abs=0.01)
[I]     Comparing Output: 'pooler_output' (dtype=float32, shape=(1, 768)) with 'pooler_output' (dtype=float32, shape=(1, 768))
[I]         Tolerance: [abs=0.01, rel=1e-05] | Checking elemwise error
[I]         trt-runner-N0-06/30/23-08:20:22: pooler_output | Stats: mean=0.0019211, std-dev=0.22539, var=0.050801, median=-0.0029383, min=-0.58057 at (0, 630), max=0.57764 at (0, 82), avg-magnitude=0.18478
[I]         onnxrt-runner-N0-06/30/23-08:20:22: pooler_output | Stats: mean=0.0019204, std-dev=0.22572, var=0.05095, median=-0.0030782, min=-0.58187 at (0, 630), max=0.57884 at (0, 680), avg-magnitude=0.18506
[I]         Error Metrics: pooler_output
[I]             Minimum Required Tolerance: elemwise error | [abs=0.0013217] OR [rel=0.65804] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0.00032893, std-dev=0.00025306, var=6.404e-08, median=0.0002878, min=4.3353e-07 at (0, 567), max=0.0013217 at (0, 472), avg-magnitude=0.00032893
[I]             Relative Difference | Stats: mean=0.005479, std-dev=0.030208, var=0.00091252, median=0.0019026, min=3.8931e-06 at (0, 377), max=0.65804 at (0, 736), avg-magnitude=0.005479
[I]         PASSED | Output: 'pooler_output' | Difference is within tolerance (rel=1e-05, abs=0.01)
[E]     FAILED | Mismatched outputs: ['last_hidden_state']

Test 2: chinese-roberta-wwm-ext

Relevant files: download the TensorFlow checkpoint from the link below:
Model link: chinese-roberta-wwm-ext tensorflow ckpt

As mentioned in issue #2466, bert4keras is still used to build the model.

2.1 Create a SavedModel

```python
# tf1.15.5 (gpu)
# bert4keras==0.11.4
import os
os.environ['TF_KERAS'] = '1'  # make bert4keras use tf.keras
import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.backend import keras, K
import tensorflow as tf

# 1. load RoBERTa
model = build_transformer_model(
    config_path="bert_config.json",
    checkpoint_path='bert_model.ckpt',
    sequence_length=128,
    #model='roberta',
    #with_mlm=False,
    return_keras_model=False
)

# 2. add a Dense(1) head and squeeze the last axis -> shape (batch, 128)
bert_output = keras.layers.Dense(units=1)(model.output)
bert_output = keras.layers.Lambda(lambda x: K.squeeze(x, axis=2))(bert_output)
model = keras.models.Model(model.input, bert_output)

sess = K.get_session()
print([i.op.name for i in model.input])
print(model.output)
input0 = tf.get_default_graph().get_tensor_by_name("Input-Token:0")
input1 = tf.get_default_graph().get_tensor_by_name("Input-Segment:0")
output1 = tf.get_default_graph().get_tensor_by_name("lambda/Squeeze:0")

inputs = {"Input-Token": input0, "Input-Segment": input1}
outputs = {"lambda": output1}  # signature output key ('lambda' in the Polygraphy log below)

# 3. save
tf.saved_model.simple_save(sess,
                           'saved_model',
                           inputs=inputs,
                           outputs=outputs)
```

2.2 Create the ONNX model with tf2onnx (1.13.0)

```
python -m tf2onnx.convert --saved-model saved_model --output roberta_wwm_ext_opset17.onnx --opset 17
```

2.3 Fuse LayerNorm

Because tf2onnx splits LayerNorm into elementary ops, they have to be fused back into a single LayerNormalization node manually (without the fusion, the FP16 result is wrong).

```python
import onnx
import onnx_graphsurgeon as gs

model_path = "roberta_wwm_ext_opset17.onnx"
graph = gs.import_onnx(onnx.load(model_path))

# Pieces of each split LayerNorm subgraph
ln_inputs = []   # x, the input of Norm/sub
gammas = []      # Scale, second input of the Mul that feeds Norm/add_1
betas = []       # B, second input of Norm/add_1
ln_outputs = []  # output tensor(s) of Norm/add_1

for node in graph.nodes:
    # epsilon (1e-12) is the constant input of 'Norm/add' (not 'add_1'):
    # if node.op == 'Add' and ('Norm/add' in node.name) and ('add_1' not in node.name):
    #     print(node.inputs[1].values)

    # Last op "+ beta": collect B, Scale and the LayerNorm output, then detach its inputs
    if node.op == 'Add' and 'Norm/add_1' in node.name:
        betas.append(node.inputs[1])
        gammas.append(node.i().inputs[1])  # producer of input 0 is the "* gamma" Mul
        ln_outputs.append(list(node.outputs))
        node.inputs.clear()

    # First op "x - mean": collect x, then detach the node's outputs
    if node.op == 'Sub' and 'Norm/sub' in node.name:
        for inp in node.inputs:
            if 'add' in inp.name:
                ln_inputs.append(inp)
                node.outputs.clear()

# Pairing by index assumes each LayerNorm's Sub and Add nodes appear in the same order
assert len(ln_inputs) == len(betas) == len(gammas) == len(ln_outputs)
for x, gamma, beta, outputs in zip(ln_inputs, gammas, betas, ln_outputs):
    fused_node = gs.Node(
        op="LayerNormalization",
        inputs=[x, gamma, beta],  # input, scale, bias
        outputs=outputs,
        attrs={'axis': -1, 'epsilon': 1e-12})
    graph.nodes.append(fused_node)

# Detach nodes left without inputs (the old Norm/add_1 nodes) so cleanup() removes the old subgraphs
for node in graph.nodes:
    if not node.inputs:
        node.outputs.clear()
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "roberta_wwm_ext_opset17_fuse_ln.onnx")
print('done')
```
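
For reference, the split subgraph replaced above computes LayerNorm step by step, while the ONNX LayerNormalization op (opset 17, axis=-1) computes the same result in one node: (x - mean) / sqrt(var + epsilon) * Scale + B. The NumPy sketch below uses hypothetical function names, and the exact op sequence tf2onnx emits may differ (e.g. Square instead of Pow); it only checks that the two forms agree in FP32, which is what the fusion relies on. It also shows where the squaring happens, the step that can overflow in FP16.

```python
import numpy as np

def layernorm_split(x, gamma, beta, eps=1e-12):
    """Step-by-step LayerNorm, roughly mirroring the split subgraph."""
    mean = x.mean(axis=-1, keepdims=True)               # ReduceMean
    centered = x - mean                                 # Sub ("Norm/sub")
    var = (centered ** 2).mean(axis=-1, keepdims=True)  # Pow/Square + ReduceMean
    std = np.sqrt(var + eps)                            # Add epsilon ("Norm/add") + Sqrt
    return centered / std * gamma + beta                # Div, Mul (gamma), Add (beta, "Norm/add_1")

def layernorm_fused(x, gamma, beta, eps=1e-12, axis=-1):
    """Single-op form following the ONNX LayerNormalization definition."""
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 128, 768)).astype(np.float32)
gamma = rng.normal(size=768).astype(np.float32)
beta = rng.normal(size=768).astype(np.float32)
print(np.allclose(layernorm_split(x, gamma, beta), layernorm_fused(x, gamma, beta), atol=1e-5))
```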

2.4 Compare TensorRT (FP16) against ONNX Runtime with Polygraphy:

```
polygraphy run roberta_wwm_ext_opset17_fuse_ln.onnx --trt --onnxrt --atol 0.01 --pool-limit workspace:10G --fp16
```

[I]     Comparing Output: 'lambda' (dtype=float32, shape=(1, 128)) with 'lambda' (dtype=float32, shape=(1, 128))
[I]         Tolerance: [abs=0.01, rel=1e-05] | Checking elemwise error
[I]         trt-runner-N0-06/30/23-09:12:15: lambda | Stats: mean=0.55965, std-dev=0.15921, var=0.025347, median=0.55591, min=0.245 at (0, 80), max=1.4902 at (0, 0), avg-magnitude=0.55965
[I]             ---- Histogram ----
                Bin Range      |  Num Elems | Visualization
                (0.245, 0.37 ) |         11 | ##########
                (0.37 , 0.494) |         33 | ###############################
                (0.494, 0.619) |         42 | ########################################
                (0.619, 0.744) |         32 | ##############################
                (0.744, 0.869) |          9 | ########
                (0.869, 0.993) |          0 |
                (0.993, 1.12 ) |          0 |
                (1.12 , 1.24 ) |          0 |
                (1.24 , 1.37 ) |          0 |
                (1.37 , 1.49 ) |          1 |
[I]         onnxrt-runner-N0-06/30/23-09:12:15: lambda | Stats: mean=0.56168, std-dev=0.15942, var=0.025416, median=0.55772, min=0.24882 at (0, 80), max=1.4923 at (0, 0), avg-magnitude=0.56168
[I]             ---- Histogram ----
                Bin Range      |  Num Elems | Visualization
                (0.245, 0.37 ) |         11 | ##########
                (0.37 , 0.494) |         32 | ##############################
                (0.494, 0.619) |         42 | ########################################
                (0.619, 0.744) |         33 | ###############################
                (0.744, 0.869) |          8 | #######
                (0.869, 0.993) |          1 |
                (0.993, 1.12 ) |          0 |
                (1.12 , 1.24 ) |          0 |
                (1.24 , 1.37 ) |          0 |
                (1.37 , 1.49 ) |          1 |
[I]         Error Metrics: lambda
[I]             Minimum Required Tolerance: elemwise error | [abs=0.0103] OR [rel=0.018481] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0.0029195, std-dev=0.0020367, var=4.1483e-06, median=0.0025767, min=9.1791e-06 at (0, 69), max=0.0103 at (0, 104), avg-magnitude=0.0029195
[I]                 ---- Histogram ----
                    Bin Range           |  Num Elems | Visualization
                    (9.18e-06, 0.00104) |         27 | ########################################
                    (0.00104 , 0.00207) |         23 | ##################################
                    (0.00207 , 0.0031 ) |         25 | #####################################
                    (0.0031  , 0.00413) |         19 | ############################
                    (0.00413 , 0.00515) |         15 | ######################
                    (0.00515 , 0.00618) |         11 | ################
                    (0.00618 , 0.00721) |          5 | #######
                    (0.00721 , 0.00824) |          2 | ##
                    (0.00824 , 0.00927) |          0 |
                    (0.00927 , 0.0103 ) |          1 | #
[I]             Relative Difference | Stats: mean=0.005554, std-dev=0.0040631, var=1.6508e-05, median=0.0051507, min=1.5757e-05 at (0, 69), max=0.018481 at (0, 94), avg-magnitude=0.005554
[I]                 ---- Histogram ----
                    Bin Range           |  Num Elems | Visualization
                    (1.58e-05, 0.00186) |         30 | ########################################
                    (0.00186 , 0.00371) |         20 | ##########################
                    (0.00371 , 0.00556) |         19 | #########################
                    (0.00556 , 0.0074 ) |         24 | ################################
                    (0.0074  , 0.00925) |         16 | #####################
                    (0.00925 , 0.0111 ) |          6 | ########
                    (0.0111  , 0.0129 ) |          5 | ######
                    (0.0129  , 0.0148 ) |          3 | ####
                    (0.0148  , 0.0166 ) |          4 | #####
                    (0.0166  , 0.0185 ) |          1 | #
[E]         FAILED | Output: 'lambda' | Difference exceeds tolerance (rel=1e-05, abs=0.01)
[E]     FAILED | Mismatched outputs: ['lambda']

Question
BERT-base is fine, so I'm not sure whether this error is caused by LayerNorm or by RoBERTa itself.
On TRT 8.5, if I set the LayerNorm plugin to FP32, the inference is correct.
However, on TRT 8.6, when I tried to set the INormalization layer to FP32, the entire model ran in FP32: visualizing the engine shows only one Myelin layer, which contains the whole network.

What can be done to ensure the accuracy of roberta fp16?

@zerollzeng
Collaborator

Did you see any FP16 warnings in the Polygraphy output, e.g. about LayerNorm or FP16 overflow?

@zerollzeng zerollzeng self-assigned this Jul 2, 2023
@zerollzeng zerollzeng added the triaged Issue has been triaged by maintainers label Jul 2, 2023
@DayDayupupupup
Author

Only FP16 overflow warnings; there are no LayerNorm warnings for roberta_wwm_ext_opset17_fuse_ln.onnx:

[W] TensorRT encountered issues when converting weights between types and that could affect accuracy.
[W] If this is not the desired behavior, please modify the weights or retrain with regularization to adjust the magnitude of the weights.
[W] Check verbose logs for the list of affected weights.
[W] - 131 weights are affected by this issue: Detected subnormal FP16 values.
[W] - 46 weights are affected by this issue: Detected values less than smallest positive FP16 subnormal value and converted them to the FP16 minimum subnormalized value.
[W] - 1 weights are affected by this issue: Detected finite FP32 values which would overflow in FP16 and converted them to the closest finite FP16 value.
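
These three warning categories correspond to FP16's representable range: normal values start at 2^-14 (about 6.1e-5), subnormals go down to 2^-24 (about 6.0e-8), and the largest finite value is 65504. A rough NumPy sketch for counting how many FP32 weights fall into each bucket before building; the function name and the exact boundary handling are assumptions, not TensorRT's internal logic.

```python
import numpy as np

FP16_MAX = float(np.finfo(np.float16).max)  # 65504.0
FP16_MIN_NORMAL = 2.0 ** -14                # smallest positive normal FP16 value
FP16_MIN_SUBNORMAL = 2.0 ** -24             # smallest positive subnormal FP16 value

def fp16_weight_report(weights):
    """Count FP32 values that would be subnormal, below the smallest subnormal,
    or overflow when cast to FP16.

    `weights` maps a name to an array, e.g. ONNX initializers converted with
    onnx.numpy_helper.to_array (not done here, to keep the sketch dependency-free).
    """
    report = {}
    for name, w in weights.items():
        a = np.abs(np.asarray(w, dtype=np.float32))
        nonzero = a > 0
        finite = np.isfinite(a)
        report[name] = {
            "subnormal": int(np.sum((a >= FP16_MIN_SUBNORMAL) & (a < FP16_MIN_NORMAL))),
            "below_min_subnormal": int(np.sum(nonzero & (a < FP16_MIN_SUBNORMAL))),
            "overflow": int(np.sum(finite & (a > FP16_MAX))),
        }
    return report

print(fp16_weight_report({"w": [0.0, 1e-9, 1e-6, 1e-3, 7e4, -7e4, 1.0]}))
```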

@ttyio
Collaborator

ttyio commented Aug 7, 2023

@DayDayupupupup, FP16 has only 5 exponent bits, compared with 8 in FP32, so depending on the input data, operations like the pow inside the normalization layer might overflow. Let's switch to FP32 when the normalization layer causes an accuracy issue. BTW, we will have BF16 support in the next release; BF16 has 8 exponent bits, the same as FP32.
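
A small NumPy illustration of the overflow described above: squaring a mean-centered value (the pow inside LayerNorm) exceeds FP16's maximum of 65504 once its magnitude passes sqrt(65504), about 255.9, while the same values are harmless in FP32. The values are made up for illustration, and NumPy has no native BF16 type, so BF16 is not shown.

```python
import numpy as np

FP16_MAX = float(np.finfo(np.float16).max)  # 65504.0
threshold = np.sqrt(FP16_MAX)               # ~255.94: a larger |x - mean| overflows when squared

# Hypothetical mean-centered activations (x - mean) entering the pow
centered = np.array([100.0, 255.0, 256.0, 300.0])

with np.errstate(over="ignore"):  # silence the expected FP16 overflow warning
    sq_fp16 = centered.astype(np.float16) ** 2
sq_fp32 = centered.astype(np.float32) ** 2

print(f"overflow threshold: |x - mean| > {threshold:.2f}")
print("fp16:", sq_fp16)  # the last two entries become inf
print("fp32:", sq_fp32)

# In a naive FP16 LayerNorm, a single inf makes the row's variance inf, so every
# centered value is divided by inf and the normalized output collapses to beta.
```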

@DayDayupupupup
Author

Thank you very much; I look forward to the next TRT release. But I am still confused. I understand that the normalization in BERT is forced to FP32, and the model structures of RoBERTa and BERT differ only slightly. If the normalization in RoBERTa overflows, does that mean RoBERTa's normalization is not forced to FP32?

@ttyio
Collaborator

ttyio commented Aug 8, 2023

@DayDayupupupup whether the model overflows depends on the input data before the pow. One debugging tip: mark the pow input as a network output, feed all your data, and check the data range before deciding which precision to use.
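
Once the LayerNorm input has been marked as an output and dumped for all your data (for example with Polygraphy's `--onnx-outputs` option), the range check can be as simple as the hypothetical helper below. It assumes the dumped activations are NumPy arrays normalized over the last axis, and reports the largest |x - mean|, which must stay below sqrt(65504), about 255.9, for the elementwise square to fit in FP16. It does not model any FP16 accumulation inside the reduction.

```python
import numpy as np

FP16_MAX = float(np.finfo(np.float16).max)  # 65504.0

def check_pow_input_range(batches, axis=-1):
    """Return the worst-case |x - mean| over all batches and whether its square fits in FP16.

    `batches` is an iterable of arrays holding the tensor that feeds the LayerNorm,
    e.g. shape (batch, seq_len, hidden) with normalization over `axis`.
    """
    worst = 0.0
    for x in batches:
        x = np.asarray(x, dtype=np.float32)
        centered = x - x.mean(axis=axis, keepdims=True)  # the value that gets squared
        worst = max(worst, float(np.abs(centered).max()))
    return worst, worst ** 2 <= FP16_MAX

# Hypothetical data: one batch is well-behaved, the other has an outlier of 400
rng = np.random.default_rng(0)
ok = rng.normal(size=(1, 4, 8)).astype(np.float32)
spiky = ok.copy()
spiky[0, 2, 3] = 400.0
print(check_pow_input_range([ok]))
print(check_pow_input_range([ok, spiky]))
```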

As for the model-structure question, we can use a visualization tool to inspect the ONNX model. Sometimes there are Cast (FP32) nodes around the norm layer, which tell TRT to force FP32.

Hope this helps!

@DayDayupupupup
Author

@ttyio As I mentioned above, I've fused the normalization into a single LayerNormalization op (opset 17).
After marking the LayerNormalization input as a new network output, the original network output is still wrong. So I suspect RoBERTa's LayerNormalization is not forced to FP32.

Visualizing the TRT engine shows only one Myelin layer node. After building the FP16 engine with the LayerNormalization layer set to FP32, the entire graph falls back to FP32.

@ttyio
Copy link
Collaborator

ttyio commented Aug 9, 2023

@DayDayupupupup, by enabling kOBEY_PRECISION_CONSTRAINTS and setting

LayerNormalization.precision = FP32
LayerNormalization.setOutputType(0, FP32)

only this LayerNorm is constrained to FP32 precision; other layers are not affected, even if they are in the same node.
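
In the TensorRT Python API, the two settings above would look roughly like the sketch below. It is a hypothetical helper, not code from this thread; it assumes TensorRT 8.6 (where the ONNX LayerNormalization op becomes an INormalization layer, `trt.LayerType.NORMALIZATION`), a 10 GiB workspace like the Polygraphy commands above, and 2-D inputs of shape (batch, 128). The `tensorrt` import sits inside the function only so the snippet can be loaded on a machine without TensorRT.

```python
def build_fp16_engine_with_fp32_norm(onnx_path, engine_path, seq_len=128, workspace_gib=10):
    """Build an FP16 engine while constraining every INormalization layer to FP32.

    Returns the names of the layers that were pinned to FP32.
    """
    import tensorrt as trt  # lazy import: only needed when actually building

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
            raise RuntimeError("ONNX parse failed:\n" + "\n".join(errors))

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_gib << 30)
    config.set_flag(trt.BuilderFlag.FP16)
    # Without this flag, per-layer precisions are only preferences the builder may ignore
    config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)

    pinned = []
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        if layer.type == trt.LayerType.NORMALIZATION:
            layer.precision = trt.float32            # compute in FP32
            layer.set_output_type(0, trt.float32)    # and emit an FP32 output
            pinned.append(layer.name)

    # The exported models have a dynamic batch axis, so a profile is required (batch fixed to 1 here)
    profile = builder.create_optimization_profile()
    shape = (1, seq_len)
    for i in range(network.num_inputs):
        profile.set_shape(network.get_input(i).name, shape, shape, shape)
    config.add_optimization_profile(profile)

    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("engine build failed")
    with open(engine_path, "wb") as f:
        f.write(serialized)
    return pinned
```

Usage on a machine with TensorRT: `pinned = build_fp16_engine_with_fp32_norm("roberta_wwm_ext_opset17_fuse_ln.onnx", "roberta_fp16.engine")`. If `pinned` is empty, no INormalization layers were found and nothing was constrained.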

@DayDayupupupup
Copy link
Author

DayDayupupupup commented Aug 11, 2023

@ttyio I have always set it like this.
The built engine is as large as the FP32 one, and inference speed does not improve.

@FdyCN
Copy link

FdyCN commented Aug 22, 2023

> After building the FP16 engine with the LayerNormalization layer set to FP32, the entire graph falls back to FP32.

Hi @DayDayupupupup, I got the same error on the BEiT model, which is based on BERT. In your case, many MatMul weights trigger the warning "Detected subnormal FP16 values". But if you only force LN to FP32, is the FP16 model's result good regardless of the MatMul weights?

Unfortunately, I tried setting LN/MatMul/Reduce/Softmax (the ops that may overflow) to FP32 on BEiT, but I still get wrong results.

@DayDayupupupup
Author

@FdyCN In my case, only LN overflows, and only a small number of weights are subnormal FP16 values, which does not cause a large error in the final result.
Are you using TRT 8.6, with LN as a fused LayerNormalization op (opset 17)?
In TRT 8.4, LN is not an independent op; I set all the ops in the LN section to FP32, but the inference result was still wrong.

4 participants