forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
build.py
933 lines (870 loc) · 32.3 KB
/
build.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import math
import time
from pathlib import Path
import torch
import torch.multiprocessing as mp
from visualize import to_onnx
from weight import (get_scaling_factors, load_from_awq, load_from_gptq,
load_from_hf, load_from_sq)
import tensorrt_llm
from tensorrt_llm import profiler
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import ChatGLMHeadModel, quantize_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode
def get_engine_name(model, dtype, tp_size, pp_size, rank):
    """Return the serialized-engine file name for a single rank.

    The pipeline-parallel segment ``_pp{pp_size}`` is only inserted when
    ``pp_size`` is not 1, so single-stage builds keep the short form
    ``{model}_{dtype}_tp{tp_size}_rank{rank}.engine``.
    """
    pp_part = '' if pp_size == 1 else f'_pp{pp_size}'
    return f'{model}_{dtype}_tp{tp_size}{pp_part}_rank{rank}.engine'
def find_engines(dir: Path,
                 model_name: str = "*",
                 dtype: str = "*",
                 tp_size: str = "*",
                 rank: str = "*"):
    """List engine files in ``dir`` matching the given name components.

    Each component defaults to the glob wildcard ``*``. The pattern is
    ``{model_name}_{dtype}_tp{tp_size}_rank{rank}.engine``, evaluated with
    ``Path.glob``.
    """
    pattern = "{}_{}_tp{}_rank{}.engine".format(model_name, dtype, tp_size,
                                                 rank)
    matches = dir.glob(pattern)
    return list(matches)
def serialize_engine(engine, path):
    """Write the serialized engine buffer ``engine`` to ``path`` in binary mode.

    Logs the destination before writing and the elapsed wall-clock time
    (HH:MM:SS) once the file has been closed.
    """
    logger.info(f'Serializing engine to {path}...')
    start = time.time()
    with open(path, 'wb') as out_file:
        out_file.write(engine)
    elapsed = time.time() - start
    pretty = time.strftime('%H:%M:%S', time.gmtime(elapsed))
    logger.info(f'Engine serialized. Total time: {pretty}')
def truncate_input_output_len(
    max_input_len,
    max_output_len,
    max_seq_length_from_config,
    is_fixed_max_position_length=False,
):
    """Clamp requested input/output lengths to the model's sequence limit.

    Returns ``(max_input_len, max_output_len, max_seq_length)``:

    * If the input alone reaches the limit, the input becomes ``limit - 1``,
      the output becomes 1, and ``max_seq_length`` is the limit.
    * If input + output exceed the limit, the output is shortened to the
      remaining room, and ``max_seq_length`` is the limit.
    * Otherwise both lengths are kept. ``max_seq_length`` is
      ``input + output``, or the configured limit when
      ``is_fixed_max_position_length`` is true.

    A truncation message is printed whenever a length is reduced.
    """
    limit = max_seq_length_from_config

    if max_input_len >= limit:
        clamped_input = limit - 1
        print("Truncate max_input_len as %d" % clamped_input)
        return clamped_input, 1, limit

    if max_input_len + max_output_len > limit:
        clamped_output = limit - max_input_len
        print("Truncate max_output_len as %d" % clamped_output)
        return max_input_len, clamped_output, limit

    if is_fixed_max_position_length:
        return max_input_len, max_output_len, limit
    return max_input_len, max_output_len, max_input_len + max_output_len
def parse_arguments(args):
    """Parse and post-process the command-line options for a ChatGLM/GLM build.

    Besides plain argparse parsing, this function:

    * replaces plugin flags given without a value with ``--dtype``;
    * reads ``<model_dir>/config.json`` and copies the architecture
      hyper-parameters (hidden size, heads, layers, ...) onto ``args``;
    * clamps ``max_input_len`` / ``max_output_len`` to the model's sequence
      limit and derives ``max_seq_length`` / ``max_num_tokens``;
    * forces the settings that in-flight batching requires;
    * derives ``args.quant_mode`` from the quantization flags.

    Args:
        args: List of argument strings to parse (``None`` means
            ``sys.argv[1:]``).

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        ValueError: if the model name is not one of the supported models.
        AssertionError: on inconsistent parallelism / KV-cache / quantization
            settings.
    """
    parser = argparse.ArgumentParser()
    # Arguments above model
    parser.add_argument(
        '--model_name',
        '-m',
        type=str,
        choices=[
            'chatglm_6b', 'chatglm2_6b', 'chatglm2_6b_32k', 'chatglm3_6b',
            'chatglm3_6b_base', 'chatglm3_6b_32k', 'glm_2b', 'glm_10b',
            'glm_10b_chinese'
        ],
        help='Name of model, use "_" rather than "-" to connect the parts',
    )
    parser.add_argument(
        '--gpus_per_node',
        type=int,
        default=8,
        help='',
    )
    parser.add_argument(
        '--world_size',
        '-ws',
        type=int,
        default=1,
        help='World size, only tensor parallelism is supported now',
    )
    parser.add_argument(
        '--tp_size',
        '-tp',
        type=int,
        default=1,
        help='Tensor parallelism size, world_size must be set if TP > 1',
    )
    parser.add_argument(
        '--pp_size',
        '-pp',
        type=int,
        default=1,
        help='Pipeline parallelism size, not supported now',
    )
    parser.add_argument(
        '--model_dir',
        '-i',
        type=Path,
        default=None,
        help='Path of model files from HF',
    )
    parser.add_argument(
        '--quant_ckpt_path',
        type=Path,
        default=None,
        help='File of AWQ (.npz) generated by quantize.py',
    )
    parser.add_argument(
        '--quantized_fp8_model_path',
        type=str,
        default=None,
        help='File of quantization (.npz) generated by hf_chatglm_convert.py',
    )
    parser.add_argument(
        '--output_dir',
        '-o',
        type=Path,
        default='engine_outputs',
        help='Path to save serialized engine and configuration json files')
    parser.add_argument(
        '--dtype',
        type=str,
        default='float16',
        choices=['float32', 'float16', 'bfloat16'],
        help='Data type of computation of the model',
    )
    parser.add_argument(
        '--logits_dtype',
        type=str,
        default='float32',
        choices=['float16', 'float32'],
        help='Data type of logits exported from the model',
    )
    parser.add_argument(
        '--strongly_typed',
        default=False,
        action='store_true',
        help='Reduce building time for FP8 in TRT >=9.1.0.1',
    )
    parser.add_argument(
        '--timing_cache',
        type=str,
        default='model.cache',
        help='File of TensorRT timing cache',
    )
    parser.add_argument(
        '--profiling_verbosity',
        type=str,
        default='layer_names_only',
        choices=['layer_names_only', 'detailed', 'none'],
        help=
        'The profiling verbosity for the generated TRT engine. Set to detailed can inspect tactic choices and kernel parameters.'
    )
    parser.add_argument(
        '--log_level',
        type=str,
        default='error',
        choices=['verbose', 'info', 'warning', 'error', 'internal_error'],
        help='Level of log information',
    )
    parser.add_argument(
        '--builder_opt',
        type=int,
        default=None,
        help='',
    )
    parser.add_argument(
        '--parallel_build',
        default=False,
        action='store_true',
        help='Build engines on multiple GPU simultaneously world_size > 1',
    )
    parser.add_argument(
        '--enable_debug_output',
        default=False,
        action='store_true',
    )
    parser.add_argument(
        '--visualize',
        default=False,
        action='store_true',
    )
    parser.add_argument(
        '--random_seed',
        type=int,
        default=None,
        help='Random seed for Torch',
    )
    # Arguments of model related
    parser.add_argument('--max_batch_size', type=int, default=8)
    parser.add_argument('--max_input_len', type=int, default=1024)
    parser.add_argument('--max_output_len', type=int, default=1024)
    parser.add_argument('--max_beam_width', type=int, default=1)
    parser.add_argument(
        '--max_num_tokens',
        type=int,
        default=2**32,
        help='The max number of tokens supported by the engine',
    )
    # Plugin flags use nargs='?' with const=None: a bare flag (no value)
    # yields None, which the loop after parsing replaces with --dtype.
    # With a non-None const, that replacement could never trigger.
    parser.add_argument(
        '--use_gpt_attention_plugin',
        nargs='?',
        const=None,
        default='float16',
        choices=['float32', 'float16', 'bfloat16', False],
        help='Activate attention plugin with optional data type',
    )
    parser.add_argument(
        '--use_gemm_plugin',
        nargs='?',
        const=None,
        type=str,
        default='float16',
        choices=['float32', 'float16', 'bfloat16', False],
        help='Activate GEMM plugin with optional data type',
    )
    parser.add_argument(
        '--use_layernorm_plugin',
        nargs='?',
        const=None,
        type=str,
        default=False,
        choices=['float32', 'float16', 'bfloat16', False],
        help=
        'Activate layernorm plugin for ChatGLM-6B / GLM-10B models with optional data type',
    )
    parser.add_argument(
        '--use_rmsnorm_plugin',
        nargs='?',
        const=None,
        type=str,
        default=False,
        choices=['float32', 'float16', 'bfloat16', False],
        help=
        'Activate rmsnorm plugin for ChatGLM2-6B* / ChatGLM3-6B* models with optional data type',
    )
    parser.add_argument(
        '--enable_context_fmha',
        default=False,
        action='store_true',
    )
    parser.add_argument(
        '--enable_context_fmha_fp32_acc',
        default=False,
        action='store_true',
    )
    parser.add_argument(
        '--multi_block_mode',
        default=False,
        action='store_true',
        help='Split long kv sequence into multiple blocks '
        '(applied to generation MHA kernels). '
        'It is beneficial when batch x num_heads cannot fully utilize GPU',
    )
    parser.add_argument(
        '--gather_all_token_logits',
        action='store_true',
        default=False,
        help='Enable both gather_context_logits and gather_generation_logits')
    parser.add_argument('--gather_context_logits',
                        action='store_true',
                        default=False,
                        help='Gather context logits')
    parser.add_argument('--gather_generation_logits',
                        action='store_true',
                        default=False,
                        help='Gather generation logits')
    parser.add_argument(
        '--use_custom_all_reduce',
        action='store_true',
        help=
        'Activates latency-optimized algorithm for all-reduce instead of NCCL',
    )
    parser.add_argument(
        '--remove_input_padding',
        default=False,
        action='store_true',
        help='',
    )
    parser.add_argument(
        '--paged_kv_cache',
        action='store_true',
        default=False,
        help='Enable paged KV cache rather than contiguous KV cache',
    )
    parser.add_argument(
        '--tokens_per_block',
        type=int,
        default=128,
        help='Number of tokens per block in paged KV cache',
    )
    parser.add_argument(
        '--use_inflight_batching',
        action='store_true',
        default=False,
        help='Activate In-Flight-Batching mode of GPT Attention Plugin',
    )
    # Arguments about quantization
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action='store_true',
        help='Quantize GEMMs to INT4/INT8',
    )
    # 'int4_gptq' is listed so that the GPTQ branch in build_rank_engine
    # (group-wise, zero-point quantize kwargs) is reachable.
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4', 'int4_awq', 'int4_gptq'],
        help='Precision for the weights when using weight-only quantization',
    )
    parser.add_argument(
        '--disable_weight_only_quant_plugin',
        default=False,
        action='store_true',
        help=
        'Use TensorRT OOTB implementation for weight quantization rather than plugins',
    )
    parser.add_argument(
        '--use_smooth_quant',
        default=False,
        action='store_true',
        help=
        'Use the SmoothQuant method to quantize activations and weights for the various GEMMs. '
        'See --per_channel and --per_token for finer-grained quantization options.'
    )
    parser.add_argument(
        '--per_token',
        default=False,
        action='store_true',
        help=
        'By default, we use a single static scaling factor to scale activations in the int8 range. '
        'per_token chooses at run time, and for each token, a custom scaling factor. '
        'The latter is usually more accurate, but a little slower.',
    )
    parser.add_argument(
        '--per_channel',
        default=False,
        action='store_true',
        help=
        'By default, we use a single static scaling factor for the GEMM\'s result. '
        'per_channel instead uses a different static scaling factor for each channel. '
        'The latter is usually more accurate, but a little slower.',
    )
    parser.add_argument(
        '--per_group',
        default=False,
        action='store_true',
        help=
        'By default, we use a single static scaling factor to scale weights in the int4 range. '
        'per_group chooses at run time, and for each group, a custom scaling factor. '
        'The flag is built for GPTQ/AWQ quantization.',
    )
    parser.add_argument(
        '--group_size',
        type=int,
        default=128,
        help='Group size used in GPTQ/AWQ quantization.',
    )
    parser.add_argument(
        '--int8_kv_cache',
        default=False,
        action='store_true',
        help=
        'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
    )
    parser.add_argument(
        '--enable_fp8',
        default=False,
        action='store_true',
        help='Use FP8 Linear layer for Attention QKV/Dense and MLP.',
    )
    parser.add_argument(
        '--fp8_kv_cache',
        default=False,
        action='store_true',
        help=
        'By default, we use dtype for KV cache. fp8_kv_cache chooses fp8 quantization for KV'
    )
    args = parser.parse_args(args)
    logger.set_level(args.log_level)

    # Plugin flags passed without a value arrive as None (const=None above);
    # resolve them to the model's compute dtype.
    plugins_args = [
        'use_gpt_attention_plugin',
        'use_gemm_plugin',
        'use_layernorm_plugin',
        'use_rmsnorm_plugin',
    ]
    for plugin_arg in plugins_args:
        if getattr(args, plugin_arg) is None:
            logger.info(
                f"{plugin_arg} set without specifying data type. Using {args.dtype} automatically."
            )
            setattr(args, plugin_arg, args.dtype)

    assert args.world_size == args.tp_size * args.pp_size, \
        f"world_size ({args.world_size}) must equal tp_size * pp_size " \
        f"({args.tp_size} * {args.pp_size})"  # only TP is supported now
    assert not (args.model_name is None and args.model_dir is None), \
        "Either model name or model directory must be provided"
    if args.model_dir is None:
        args.model_dir = Path(args.model_name)
    with open(args.model_dir / "config.json", "r", encoding="utf-8") as f:
        js = json.load(f)
    if args.model_name is None:
        # e.g. "THUDM/chatglm2-6b" -> "chatglm2_6b"
        args.model_name = js["_name_or_path"].split("/")[-1].replace("-", "_")
    if args.output_dir is None:
        args.output_dir = Path("output_" + args.model_name)

    if args.model_name in [
            "chatglm_6b",
            "glm_2b",
            "glm_10b",
            "glm_10b_chinese",
    ]:
        assert args.max_input_len < js["max_sequence_length"]

    # Copy the architecture hyper-parameters from config.json; the key names
    # differ between the ChatGLM-6B / ChatGLM2-3 / GLM families.
    if args.model_name in ["chatglm_6b"]:
        args.apply_query_key_layer_scaling = False
        args.apply_residual_connection_post_layernorm = False
        args.ffn_hidden_size = js["inner_hidden_size"]
        args.hidden_act = 'gelu'
        args.hidden_size = js["hidden_size"]
        args.linear_bias = True
        args.max_input_len, args.max_output_len, args.max_seq_length = \
            truncate_input_output_len(
                args.max_input_len,
                args.max_output_len,
                min(args.max_num_tokens, js["max_sequence_length"]),
            )
        args.max_num_tokens = args.max_batch_size * args.max_seq_length
        args.multi_block_mode = False
        args.multi_query_mode = False
        args.norm_epsilon = js["layernorm_epsilon"]
        args.num_heads = js["num_attention_heads"]
        args.num_kv_heads = js["num_attention_heads"]
        args.num_layers = js["num_layers"]
        args.qkv_bias = True
        args.rmsnorm = False
        args.rotary_embedding_scaling = 1.0
        args.use_cache = js["use_cache"]
        args.vocab_size = js["vocab_size"]
    elif args.model_name in [
            "chatglm2_6b",
            "chatglm2_6b_32k",
            "chatglm3_6b",
            "chatglm3_6b_base",
            "chatglm3_6b_32k",
    ]:
        args.apply_query_key_layer_scaling = False
        args.apply_residual_connection_post_layernorm = js[
            "apply_residual_connection_post_layernorm"]
        args.ffn_hidden_size = js["ffn_hidden_size"]
        args.hidden_act = 'swiglu'
        args.hidden_size = js["hidden_size"]
        args.linear_bias = js["add_bias_linear"]
        args.max_input_len, args.max_output_len, args.max_seq_length = \
            truncate_input_output_len(
                args.max_input_len,
                args.max_output_len,
                min(args.max_num_tokens, js["seq_length"]),
            )
        args.max_num_tokens = args.max_batch_size * args.max_seq_length
        args.multi_block_mode = False
        args.multi_query_mode = False  # regardless of config.json
        args.norm_epsilon = js["layernorm_epsilon"]
        args.num_heads = js["num_attention_heads"]
        args.num_kv_heads = js["multi_query_group_num"]
        args.num_layers = js["num_layers"]
        args.qkv_bias = js["add_qkv_bias"]
        args.rmsnorm = js["rmsnorm"]
        if args.model_name in ["chatglm2_6b_32k", "chatglm3_6b_32k"]:
            args.rotary_embedding_scaling = js["rope_ratio"]
        else:
            args.rotary_embedding_scaling = 1.0
        args.use_cache = js["use_cache"]
        args.vocab_size = js["padded_vocab_size"]
    elif args.model_name in ["glm_2b", "glm_10b", "glm_10b_chinese"]:
        args.apply_query_key_layer_scaling = False
        args.apply_residual_connection_post_layernorm = False
        args.ffn_hidden_size = 4 * js["hidden_size"]
        args.hidden_act = 'gelu'
        args.hidden_size = js["hidden_size"]
        args.linear_bias = True
        # GLM keeps max_seq_length pinned to the configured limit
        # (is_fixed_max_position_length=True).
        args.max_input_len, args.max_output_len, args.max_seq_length = \
            truncate_input_output_len(
                args.max_input_len,
                args.max_output_len,
                min(args.max_num_tokens, js["max_sequence_length"]),
                True,
            )
        args.max_num_tokens = args.max_batch_size * args.max_seq_length
        args.multi_block_mode = False
        args.multi_query_mode = False
        args.norm_epsilon = 1.0e-5
        args.num_heads = js["num_attention_heads"]
        args.num_kv_heads = js["num_attention_heads"]
        args.num_layers = js["num_layers"]
        args.qkv_bias = True
        args.rmsnorm = False
        args.rotary_embedding_scaling = 1.0
        args.use_cache = True
        args.vocab_size = js["vocab_size"]
    else:
        # Fail here with a clear message instead of an AttributeError on a
        # missing hyper-parameter much later in the build.
        raise ValueError(f"Unsupported model: {args.model_name}")

    if args.use_inflight_batching:
        if not args.use_gpt_attention_plugin:
            args.use_gpt_attention_plugin = 'float16'
            logger.info(
                f"Using {args.use_gpt_attention_plugin} of GPT attention plugin for In-Flight-Batching mode"
            )
        if not args.remove_input_padding:
            args.remove_input_padding = True
            logger.info(
                "Using remove_input_padding for In-Flight-Batching mode")
        if not args.paged_kv_cache:
            args.paged_kv_cache = True
            logger.info("Using paged_kv_cache for In-Flight-Batching mode")

    assert (math.log2(args.tokens_per_block).is_integer()
            ), "tokens_per_block must be power of 2"
    if args.enable_context_fmha or args.enable_context_fmha_fp32_acc:
        assert (args.tokens_per_block >=
                128), "Context fMHA requires >= 128 tokens per block"

    assert not (args.use_smooth_quant and args.use_weight_only), \
        "SmoothQuant and INT8-weight-only can not be set at the same time"
    if args.use_smooth_quant:
        args.quant_mode = QuantMode.use_smooth_quant(
            args.per_token,
            args.per_channel,
        )
    elif args.use_weight_only:
        args.quant_mode = QuantMode.from_description(
            quantize_weights=True,
            quantize_activations=False,
            per_token=False,
            per_channel=False,
            per_group=args.per_group,
            use_int4_weights="int4" in args.weight_only_precision,
        )
    else:
        args.quant_mode = QuantMode(0)
    if args.int8_kv_cache:
        args.quant_mode = args.quant_mode.set_int8_kv_cache()
    elif args.fp8_kv_cache:
        args.quant_mode = args.quant_mode.set_fp8_kv_cache()
    if args.enable_fp8:
        args.quant_mode = args.quant_mode.set_fp8_qdq()

    # Expand the umbrella flag before dumping the arguments so the log shows
    # the values actually used for the build.
    if args.gather_all_token_logits:
        args.gather_context_logits = True
        args.gather_generation_logits = True

    logger.info(' Build Arguments '.center(100, '='))
    for k, v in vars(args).items():
        logger.info(f' - {k.ljust(40, ".")}: {v}')
    logger.info('=' * 100)

    return args
def build_rank_engine(
    builder: Builder,
    builder_config: tensorrt_llm.builder.BuilderConfig,
    engine_name: str,
    rank: int,
    args: argparse.Namespace,
):
    '''
    @brief: Build the engine on the given rank.
    @param builder: The Builder used to create the network and engine.
    @param builder_config: Config from builder.create_builder_config; also
        written to <output_dir>/config.json on rank 0.
    @param engine_name: Name assigned to the TensorRT network.
    @param rank: The rank to build the engine.
    @param args: The cmd line arguments as post-processed by parse_arguments.
        Side effect: args.mapping is (re)assigned for this rank.
    @return: The built engine, as returned by builder.build_engine.
    '''
    # sys.exit() rather than the exit() builtin, which is injected by the
    # `site` module and is not guaranteed to exist (e.g. under `python -S`).
    import sys

    args.mapping = Mapping(
        world_size=args.world_size,
        rank=rank,
        tp_size=args.tp_size,
        pp_size=args.pp_size,
    )
    assert args.num_layers % args.pp_size == 0, \
        f"num_layers {args.num_layers} must be a multiple of PP size {args.pp_size}"

    profiler.print_memory_usage(f'Rank {rank} Engine build starts')

    # Initialize Module
    trtllm_model = ChatGLMHeadModel(
        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
        apply_residual_connection_post_layernorm=args.
        apply_residual_connection_post_layernorm,
        dtype=args.dtype,
        enable_debug_output=args.enable_debug_output,
        ffn_hidden_size=args.ffn_hidden_size,
        hidden_act=args.hidden_act,
        hidden_size=args.hidden_size,
        linear_bias=args.linear_bias,
        logits_dtype=args.logits_dtype,
        mapping=args.mapping,
        max_batch_size=args.max_batch_size,
        max_beam_width=args.max_beam_width,
        max_input_len=args.max_input_len,
        max_num_tokens=args.max_num_tokens,
        max_output_len=args.max_output_len,
        max_seq_length=args.max_seq_length,
        model_name=args.model_name,
        norm_epsilon=args.norm_epsilon,
        num_heads=args.num_heads,
        num_kv_heads=args.num_kv_heads,
        num_layers=args.num_layers,
        qkv_bias=args.qkv_bias,
        quant_mode=args.quant_mode,
        rmsnorm=args.rmsnorm,
        rotary_embedding_scaling=args.rotary_embedding_scaling,
        tokens_per_block=args.tokens_per_block,
        use_cache=args.use_cache,
        vocab_size=args.vocab_size,
    )

    # Extra keyword arguments for quantize_model, depending on the scheme.
    quantize_kwargs = {}
    if args.use_smooth_quant or args.use_weight_only:
        if args.weight_only_precision == 'int4_awq':
            quantize_kwargs = {
                "group_size": args.group_size,
                "zero": False,
                "pre_quant_scale": True,
                "exclude_modules": [],
            }
        elif args.weight_only_precision == 'int4_gptq':
            quantize_kwargs = {
                "group_size": args.group_size,
                "zero": True,
                "pre_quant_scale": False,
            }
    elif args.enable_fp8 or args.fp8_kv_cache:
        logger.info(
            f'Loading scaling factors from {args.quantized_fp8_model_path}')
        quant_scales = get_scaling_factors(
            args.quantized_fp8_model_path,
            num_layers=args.num_layers,
            quant_mode=args.quant_mode,
        )
        quantize_kwargs = {"quant_scales": quant_scales}
    trtllm_model = quantize_model(
        trtllm_model,
        args.quant_mode,
        **quantize_kwargs,
    )

    # Load weights: group-wise (AWQ/GPTQ) checkpoint, SmoothQuant export, or
    # the original HF checkpoint.
    if args.per_group:  # load from AWQ weights
        load_func = load_from_awq if args.weight_only_precision == 'int4_awq' else load_from_gptq
        load_func(
            trtllm_model,
            args.quant_ckpt_path,
            mapping=args.mapping,
            dtype=args.dtype,
            model_name=args.model_name,
        )
    elif args.use_smooth_quant:
        load_from_sq(
            trtllm_model,
            args.model_dir,
            mapping=args.mapping,
            model_name=args.model_name,
        )
    else:  # load from original model
        load_from_hf(
            trtllm_model,
            args.model_dir,
            mapping=args.mapping,
            dtype=args.dtype,
            model_name=args.model_name,
        )
    profiler.print_memory_usage(f'Rank {rank} model weight loaded.')

    # Module -> Network
    network = builder.create_network()
    network.trt_network.name = engine_name
    if args.use_gpt_attention_plugin:
        network.plugin_config.set_gpt_attention_plugin(
            dtype=args.use_gpt_attention_plugin)
    if args.use_gemm_plugin:
        if not args.enable_fp8:
            network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
        else:
            logger.info("FP8 is not supported with Gemm plugin")
    if args.use_layernorm_plugin:
        network.plugin_config.set_layernorm_plugin(
            dtype=args.use_layernorm_plugin)
    if args.use_rmsnorm_plugin:
        network.plugin_config.set_rmsnorm_plugin(dtype=args.use_rmsnorm_plugin)
    # The two context-FMHA variants are mutually exclusive.
    assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
    if args.enable_context_fmha:
        network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
    if args.enable_context_fmha_fp32_acc:
        network.plugin_config.set_context_fmha(
            ContextFMHAType.enabled_with_fp32_acc)
    if args.multi_block_mode:
        network.plugin_config.enable_mmha_multi_block_mode()

    # Quantization plugins
    if args.use_smooth_quant:
        network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
        network.plugin_config.set_rmsnorm_quantization_plugin(dtype=args.dtype)
        network.plugin_config.set_quantize_tensor_plugin()
        network.plugin_config.set_quantize_per_token_plugin()
    if args.use_weight_only and not args.disable_weight_only_quant_plugin:
        if args.per_group:
            network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin(
                dtype='float16')
        else:
            network.plugin_config.set_weight_only_quant_matmul_plugin(
                dtype='float16')
    if args.world_size > 1:
        network.plugin_config.set_nccl_plugin(
            args.dtype,
            args.use_custom_all_reduce,
        )
    if args.remove_input_padding:
        network.plugin_config.enable_remove_input_padding()
    if args.paged_kv_cache:
        network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)

    with net_guard(network):
        # Prepare
        network.set_named_parameters(trtllm_model.named_parameters())

        # Forward
        inputs = trtllm_model.prepare_inputs(
            max_batch_size=args.max_batch_size,
            max_input_len=args.max_input_len,
            max_output_len=args.max_output_len,
            use_cache=True,
            max_beam_width=args.max_beam_width,
            gather_context_logits=args.gather_context_logits,
            gather_generation_logits=args.gather_generation_logits,
            use_custom_all_reduce=args.use_custom_all_reduce,
        )
        trtllm_model(*inputs)

        if args.enable_debug_output:
            # mark intermediate nodes' outputs
            for k, v in trtllm_model.named_network_outputs():
                v = v.trt_tensor
                v.name = k
                network.trt_network.mark_output(v)
                v.dtype = str_dtype_to_trt(args.dtype)

    tensorrt_llm.graph_rewriting.optimize(network)

    if args.visualize:
        model_path = args.output_dir / (args.model_name + '.onnx')
        to_onnx(network.trt_network, model_path)
        logger.info(f"Export network into {model_path}, skip engine building")
        sys.exit()

    # Network -> Engine
    engine = builder.build_engine(network, builder_config)
    if rank == 0:
        config_path = args.output_dir / 'config.json'
        builder.save_config(builder_config, config_path)
    return engine
def build(rank, args):
    '''
    @brief: Build and serialize the engines handled by one process.
    @param rank: Index of this process. With args.parallel_build only the
        engine whose rank equals this value is built. Otherwise every rank
        in range(args.world_size) is built sequentially, sharing one
        in-memory timing cache.
    @param args: The cmd line arguments as post-processed by parse_arguments.
    '''
    torch.cuda.set_device(rank % args.gpus_per_node)
    args.output_dir.mkdir(parents=True, exist_ok=True)
    timing_cache = args.timing_cache
    builder = Builder()

    for cur_rank in range(args.world_size):
        # skip other ranks if parallel_build is enabled
        if args.parallel_build and cur_rank != rank:
            continue
        # NOTE: int8 flag is required to be true when INT8 tensors are exposed to TRT
        # TRT-LLM has INT8 I/O when act/weights are quantized without group-scaling (AWQ, GPTQ)
        # OR INT8 KV cache is set to contiguous (without paged KV cache enabled).
        int8_trt_flag = (args.quant_mode.has_act_or_weight_quant()
                         and not args.quant_mode.has_per_group_scaling()) or (
                             not args.paged_kv_cache
                             and args.quant_mode.has_int8_kv_cache())
        builder_config = builder.create_builder_config(
            precision=args.dtype,
            timing_cache=timing_cache,
            profiling_verbosity=args.profiling_verbosity,
            tensor_parallel=args.tp_size,
            pipeline_parallel=args.pp_size,
            int8=int8_trt_flag,
            fp8=args.enable_fp8,
            strongly_typed=args.strongly_typed,
            opt_level=args.builder_opt,
            hardware_compatibility=None,
            apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
            gather_context_logits=args.gather_context_logits,
            gather_generation_logits=args.gather_generation_logits,
            hidden_act=args.hidden_act,
            hidden_size=args.hidden_size,
            max_batch_size=args.max_batch_size,
            max_beam_width=args.max_beam_width,
            max_input_len=args.max_input_len,
            max_num_tokens=args.max_num_tokens,
            max_output_len=args.max_output_len,
            max_position_embeddings=args.max_seq_length,
            multi_query_mode=args.multi_query_mode,
            name=args.model_name,
            num_heads=args.num_heads,
            num_kv_heads=args.num_kv_heads,
            num_layers=args.num_layers,
            paged_kv_cache=args.paged_kv_cache,
            parallel_build=args.parallel_build,
            quant_mode=args.quant_mode,
            remove_input_padding=args.remove_input_padding,
            vocab_size=args.vocab_size,
        )
        # get_engine_name's third parameter is the TP degree. Passing
        # world_size (= tp_size * pp_size) would mislabel PP builds.
        engine_name = get_engine_name(
            args.model_name,
            args.dtype,
            args.tp_size,
            args.pp_size,
            cur_rank,
        )
        engine = build_rank_engine(
            builder,
            builder_config,
            engine_name,
            cur_rank,
            args,
        )
        assert engine is not None, f'Failed to build engine for rank {cur_rank}'

        # KV heads are sharded across tensor-parallel ranks only (ceil division,
        # so e.g. 2 KV groups on 4 TP ranks still gives 1 head per rank).
        local_num_kv_heads = (args.num_kv_heads + args.tp_size -
                              1) // args.tp_size
        kv_dtype = str_dtype_to_trt(args.dtype)
        if args.quant_mode.has_int8_kv_cache():
            kv_dtype = str_dtype_to_trt('int8')
        elif args.quant_mode.has_fp8_kv_cache():
            kv_dtype = str_dtype_to_trt('fp8')
        profiler.check_gpt_mem_usage(
            engine=engine,
            kv_dtype=kv_dtype,
            use_gpt_attention_plugin=args.use_gpt_attention_plugin,
            paged_kv_cache=args.paged_kv_cache,
            max_batch_size=args.max_batch_size,
            max_beam_width=args.max_beam_width,
            max_input_len=args.max_input_len,
            max_output_len=args.max_output_len,
            local_num_kv_heads=local_num_kv_heads,
            head_size=args.hidden_size // args.num_heads,
            num_layers=args.num_layers)

        if cur_rank == 0:
            # Use in-memory timing cache for multiple builder passes.
            if not args.parallel_build:
                timing_cache = builder_config.trt_builder_config.get_timing_cache(
                )

        serialize_engine(engine, args.output_dir / engine_name)
        del engine
        profiler.print_memory_usage(f'Rank {cur_rank} Engine serialized')

    if rank == 0:
        ok = builder.save_timing_cache(builder_config, args.timing_cache)
        assert ok, "Failed to save timing cache."
def run_build(args=None):
    """Parse ``args`` and build every rank's engine.

    When ``--parallel_build`` is set, ``world_size > 1`` and at least
    ``world_size`` CUDA devices are visible, one process per rank is spawned
    via ``torch.multiprocessing.spawn``. Otherwise ``args.parallel_build`` is
    reset to ``False`` and all ranks are built one after another in this
    process. The total elapsed time is logged at the end.
    """
    args = parse_arguments(args)

    if args.random_seed is not None:
        torch.manual_seed(args.random_seed)

    started_at = time.time()

    spawn_workers = (args.parallel_build and args.world_size > 1
                     and torch.cuda.device_count() >= args.world_size)
    if not spawn_workers:
        args.parallel_build = False
        logger.info('Serially build TensorRT engines.')
        build(0, args)
    else:
        logger.warning(
            f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
        )
        mp.spawn(build, nprocs=args.world_size, args=(args, ))

    elapsed = time.strftime('%H:%M:%S',
                            time.gmtime(time.time() - started_at))
    logger.info(
        f'Total time of building all {args.world_size} engines: {elapsed}')
if __name__ == '__main__':
    # Script entry point: run_build() with no arguments makes argparse read
    # sys.argv[1:], then builds and serializes the engine(s).
    run_build()