Skip to content

Commit

Permalink
Add the export model process in mlperf codes (#1602)
Browse files Browse the repository at this point in the history
Signed-off-by: YIYANGCAI <[email protected]>
  • Loading branch information
YIYANGCAI authored Feb 23, 2024
1 parent e22c61e commit 354791d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def forward(self, *inp, **kwargs):
parser.add_argument('--use_max_length', action='store_true',
help='Only select data whose length equals or more than model.seqlen, please refer to GPTQ original implementation'
)
parser.add_argument('--benchmark', action='store_true', help='Whether to do benchmark on CNN datasets.')

# load the gptj model
args = parser.parse_args()
Expand Down Expand Up @@ -324,12 +325,13 @@ def forward(self, *inp, **kwargs):

q_model = quantization.fit(model, conf, calib_dataloader=dataloader,)

q_model.save("./gptj-gptq-gs128-calib128-calibration-fp16/")
# q_model.save("./gptj-gptq-gs128-calib128-calibration-fp16/")
# q_model.float()
# q_model.save("./gptj-gptq-gs128-calib128-calibration-fp32/")
compressed_model = q_model.export_compressed_model()
torch.save(compressed_model.state_dict(), "gptj_w3g128_compressed_model.pt")
# benchmarking first 100 examples
# if args.benchmark:
if True:
if args.benchmark:
# use half to accerlerate inference
model.half()
model = model.to(DEV)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ CALIBRATION_DATA=/your/data/calibration-data/cnn_dailymail_calibration.json
VALIDATION_DATA=/your/data/validation-data/cnn_dailymail_validation.json
MODEL_DIR=/your/gptj/

python -u examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.py \
python -u examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_gptj_mlperf_int4.py \
--model_name_or_path ${MODEL_DIR} \
--wbits 4 \
--wbits 3 \
--sym \
--group_size -1 \
--nsamples 128 \
--group_size 128 \
--nsamples 256 \
--calib-data-path ${CALIBRATION_DATA} \
--val-data-path ${VALIDATION_DATA} \
--calib-iters 128 \
--calib-iters 256 \
--use_max_length \
--pad_max_length 2048 \
--use_gpu
--use_gpu

0 comments on commit 354791d

Please sign in to comment.