You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
def _load_predictor(ckpt_path, gpt_pad_id, vocab_size, *, device, label,
                    verbose=False, model_kwargs=None):
    """Build a FUDGE ``Model`` predictor, load its checkpoint, and put it in eval mode.

    Args:
        ckpt_path: path to a checkpoint dict with 'args', 'state_dict' and 'epoch' keys.
        gpt_pad_id: pad-token id of the base LM's tokenizer, forwarded to ``Model``.
        vocab_size: size of the dataset vocabulary (``len(dataset_info.index2word)``).
        device: device string the checkpoint is mapped to and the model is moved to.
        label: short name used in log / error messages ('iambic', 'rhyme', ...).
        verbose: if True, print the checkpoint epoch and parameter count.
        model_kwargs: extra keyword arguments forwarded to ``Model``.

    Returns:
        The loaded predictor, on ``device``, in eval mode.

    Raises:
        RuntimeError: if the checkpoint weights do not fit the freshly built model,
            chained to the original ``load_state_dict`` error.
    """
    checkpoint = torch.load(ckpt_path, map_location=device)
    model = Model(checkpoint['args'], gpt_pad_id, vocab_size, **(model_kwargs or {}))
    try:
        model.load_state_dict(checkpoint['state_dict'])
    except RuntimeError as err:
        # The most common cause is a size mismatch on gpt_embed.weight. The
        # predictor embeds the *base LM's* token ids, so a checkpoint trained
        # with one tokenizer (e.g. GPT-2, 50258 rows) cannot be reused with
        # another (e.g. GLM-4). Explain that instead of surfacing a bare shape error.
        raise RuntimeError(
            f"Could not load the {label} predictor from '{ckpt_path}'. If the cause "
            f"below is a size mismatch on gpt_embed.weight, this checkpoint was "
            f"trained with a different tokenizer than the one loaded now "
            f"(pad id {gpt_pad_id}); retrain the {label} predictor with the current "
            f"tokenizer before using it for conditioning."
        ) from err
    model = model.to(device)
    model.eval()
    if verbose:
        print("=> loaded checkpoint '{}' (epoch {})".format(ckpt_path, checkpoint['epoch']))
        print('{} model num params'.format(label), num_params(model))
    return model


def main(args):
    """Generate the second line of a couplet for every prefix in ``args.prefix_file``.

    Loads the base causal LM plus the iambic, rhyme and newline FUDGE predictors.
    For each prefix line it calls ``predict_couplet`` and prints the generated
    second line with newlines removed.

    Args:
        args: parsed CLI namespace. It must provide dataset_info, rhyme_info,
            model_string, iambic_ckpt, rhyme_ckpt, newline_ckpt, prefix_file,
            precondition_topk, topk, condition_lambda and device; verbose is optional.
    """
    # The CLI parser may not declare --verbose; treat a missing flag as False
    # instead of crashing with AttributeError.
    verbose = getattr(args, 'verbose', False)

    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)
    vocab_size = len(dataset_info.index2word)

    # Honour --model_string instead of a hard-coded path (its default is the same path).
    model_path = args.model_string
    gpt_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    # NOTE(review): encode(...)[0] assumes PAD_TOKEN maps to a single token id for
    # this tokenizer; if it is split into pieces only the first piece is used — confirm.
    gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN, add_special_tokens=False)[0]
    gpt_model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(args.device)
    gpt_model.eval()

    # Load the iambic-meter predictor.
    iambic_model = _load_predictor(args.iambic_ckpt, gpt_pad_id, vocab_size,
                                   device=args.device, label='iambic', verbose=verbose)

    with open(args.rhyme_info, 'rb') as rf:
        rhyme_info = pickle.load(rf)
    # GloVe embeddings are not needed when reloading; they are stored in the checkpoint.
    rhyme_model = _load_predictor(
        args.rhyme_ckpt, gpt_pad_id, vocab_size,
        device=args.device, label='rhyme', verbose=verbose,
        model_kwargs={'rhyme_group_size': len(rhyme_info.index2rhyme_group),
                      'verbose': verbose},
    )

    newline_model = _load_predictor(args.newline_ckpt, gpt_pad_id, vocab_size,
                                    device=args.device, label='newline', verbose=verbose)

    # Explicit UTF-8 so non-ASCII (e.g. Chinese) prefixes decode regardless of locale.
    with open(args.prefix_file, 'r', encoding='utf-8') as rf:
        lines = rf.readlines()
    for line in tqdm(lines):
        couplet = predict_couplet(gpt_model,
                                  gpt_tokenizer,
                                  iambic_model,
                                  rhyme_model,
                                  newline_model,
                                  [line],
                                  dataset_info,
                                  rhyme_info,
                                  args.precondition_topk,
                                  args.topk,
                                  condition_lambda=args.condition_lambda,
                                  device=args.device)
        assert len(couplet) == 2
        print(couplet[1].strip().replace('\n', ''))
if __name__ == '__main__':
    parser = ArgumentParser()

    # DATA: predictor checkpoints and the pickled metadata saved with them
    parser.add_argument('--iambic_ckpt', type=str, default='ckpt/poetry/iambic_predictor/model.pth.tar')
    parser.add_argument('--rhyme_ckpt', type=str, default='ckpt/poetry/rhyme_predictor/model.pth.tar')
    parser.add_argument('--newline_ckpt', type=str, default='ckpt/poetry/newline_predictor/model.pth.tar')
    parser.add_argument('--dataset_info', type=str, help='saved dataset info', default='ckpt/poetry/rhyme_predictor/dataset_info')
    parser.add_argument('--rhyme_info', type=str, help='saved rhyme info', default='ckpt/poetry/rhyme_predictor/rhyme_info')
    parser.add_argument('--model_string', type=str, default='/home/jiangsiyuan/glm-4-9b-chat')
    parser.add_argument('--prefix_file', type=str, default='poetry_data/couplet_prefixes.txt', help='file of prefix lines for couplets')

    # DECODING
    parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
    parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
    parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')

    # MISC
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
    parser.add_argument('--debug', action='store_true', default=False)
    # main() reads args.verbose; without this flag parse_args() returns a
    # Namespace lacking that attribute and main() fails with AttributeError.
    parser.add_argument('--verbose', action='store_true', default=False, help='print checkpoint epoch and parameter counts')

    args = parser.parse_args()

    # Seed every RNG the pipeline may touch so runs are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    main(args)
求大佬帮助!谢谢
The text was updated successfully, but these errors were encountered:
Enermy
changed the title
在使用FUDGE时出现报错:RuntimeError: Error(s) in loading state_dict for Model: size mismatch for gpt_embed.weight: copying a param with shape torch.Size([50258, 300]) from checkpoint, the shape in current model is torch.Size([151344, 300]).
在使用FUDGE时出现报错:RuntimeError: Error(s) in loading state_dict for Model:
Oct 22, 2024
使用FUDGE([yangkevin2/naacl-2021-fudge-controlled-generation](https://github.com/yangkevin2/naacl-2021-fudge-controlled-generation))对GLM-4进行押韵诗歌生成的时候遇到一个问题:
完整代码如下:
求大佬帮助!谢谢
The text was updated successfully, but these errors were encountered: