This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Keras GPT-2


[中文|English]

Load pretrained GPT-2 weights and generate predictions with them.

Install

pip install keras-gpt-2
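
To confirm that the package is visible to the interpreter you intend to use, a quick check along these lines may help. This is a minimal sketch using only the standard library; `is_installed` is a hypothetical helper name introduced here, not part of keras-gpt-2.

```python
import importlib.util


def is_installed(module_name):
    """Return True if `module_name` can be located by the current interpreter."""
    # find_spec returns None when a top-level module cannot be found.
    return importlib.util.find_spec(module_name) is not None


# `keras_gpt_2` is the import name used in the example in this README.
print(is_installed('keras_gpt_2'))
```

If this prints `False`, the package was probably installed into a different environment than the one running the script.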

Example

import os
from keras_gpt_2 import load_trained_model_from_checkpoint, get_bpe_from_files, generate


model_folder = 'xxx/yyy/117M'
config_path = os.path.join(model_folder, 'hparams.json')
checkpoint_path = os.path.join(model_folder, 'model.ckpt')
encoder_path = os.path.join(model_folder, 'encoder.json')
vocab_path = os.path.join(model_folder, 'vocab.bpe')


print('Load model from checkpoint...')
model = load_trained_model_from_checkpoint(config_path, checkpoint_path)
print('Load BPE from files...')
bpe = get_bpe_from_files(encoder_path, vocab_path)
print('Generate text...')
output = generate(model, bpe, ['From the day forth, my arm'], length=20, top_k=1)

# With the 117M model and top_k set to 1, the output is:
# "From the day forth, my arm was broken, and I was in a state of pain. I was in a state of pain,"
print(output[0])
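
The output above is reproducible because `top_k=1` keeps only the single most likely token at each step, which amounts to greedy decoding. The following is a minimal, standard-library sketch of the general top-k sampling idea. It is not the library's implementation: `top_k_sample` and the example scores are hypothetical and exist only for illustration, and the sketch assumes `top_k` is at least 1.

```python
import math
import random


def top_k_sample(logits, top_k, rng=None):
    """Pick a token index from `logits`, considering only the `top_k` largest."""
    rng = rng or random.Random()
    # Indices of the top_k highest-scoring tokens, best first.
    candidates = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    # Softmax weights over the surviving candidates (shifted by the max for stability).
    peak = logits[candidates[0]]
    weights = [math.exp(logits[i] - peak) for i in candidates]
    # Draw one candidate in proportion to its weight.
    return rng.choices(candidates, weights=weights, k=1)[0]


logits = [0.1, 2.5, -1.0, 2.4]  # made-up scores for a 4-token vocabulary
print(top_k_sample(logits, top_k=1))  # always 1, the index of the largest score
print(top_k_sample(logits, top_k=2, rng=random.Random(0)))  # either 1 or 3
```

With a larger `top_k`, repeated runs can produce different continuations, so the sample output quoted above should only be expected when `top_k` is 1.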