[Code Share] Beam Search Implementation #53
shjas94 started this conversation in Show and tell
Replies: 1 comment
-
```python
# Assumes `import random` and `import torch` at module level,
# and a module-level `device`, as elsewhere in the codebase.
def forward(
    self, src, text, is_train=True, batch_max_length=50, teacher_forcing_ratio=1.0
):
    if is_train and random.random() < teacher_forcing_ratio:
        # Teacher forcing: decode the whole ground-truth sequence in one pass.
        tgt = self.text_embedding(text)
        tgt = self.pos_encoder(tgt)
        tgt_mask = self.pad_mask(text) | self.order_mask(text.size(1))
        for layer in self.attention_layers:
            tgt = layer(tgt, None, src, tgt_mask)
        out = self.generator(tgt)
    else:
        # Beam search decoding, one token at a time.
        num_steps = batch_max_length - 1
        # Each hypothesis: [token sequence, score, per-layer feature cache, per-step logits].
        temp_tar = [[
            torch.full((src.size(0), 1), self.st_id, dtype=torch.long, device=device),
            torch.zeros(src.size(0), device=device),
            [None] * self.layer_num,
            [],
        ]]
        k = self.k
        for t in range(num_steps):
            new_tar = []
            for tar in temp_tar:
                target = tar[0][:, -1].unsqueeze(1)  # last token of this hypothesis, [b, 1]
                pre_prob = tar[1]
                pre_feature = tar[2]
                temp_out = tar[3]
                tgt = self.text_embedding(target)
                tgt = self.pos_encoder(tgt, point=t)
                tgt_mask = self.order_mask(t + 1)
                tgt_mask = tgt_mask[:, -1].unsqueeze(1)  # [1, t+1]
                for l, layer in enumerate(self.attention_layers):
                    tgt = layer(tgt, pre_feature[l], src, tgt_mask)
                    pre_feature[l] = (
                        tgt if pre_feature[l] is None
                        else torch.cat([pre_feature[l], tgt], 1)
                    )
                _out = self.generator(tgt)  # [b, 1, num_classes]
                temp_out.append(_out)
                prob, idx = torch.topk(_out[:, -1, :], k=k, dim=-1)  # each [b, k]
                for j in range(k):
                    # Shallow-copy the caches so the k children do not share state.
                    new_tar.append([
                        torch.cat([tar[0], idx[:, j].unsqueeze(1)], dim=1),
                        (prob[:, j] + pre_prob) / 2,  # running average of step scores
                        list(pre_feature),
                        list(temp_out),
                    ])
            # new_tar now holds [sequence, score] candidates; keep only the k best
            # and drop the rest. Scores are summed over the batch, so pruning is
            # shared across the whole batch rather than done per element.
            sorted_tar = sorted(new_tar, key=lambda x: x[1].sum().item(), reverse=True)
            temp_tar = sorted_tar[:k]
        # Concatenate the per-step logits of the best hypothesis:
        # [b, batch_max_length - 1, num_classes].
        out = torch.cat(temp_tar[0][-1], dim=1)
    return out
```
An initial prototype that does not yet handle batching.
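For reference, here is a minimal self-contained sketch of the same pruning idea for a single sequence, scoring hypotheses by summed log-probabilities rather than the running average used above. `step_fn` (one decoder step returning next-token log-probabilities), `start_id`, and `end_id` are hypothetical placeholders, not names from this repo:

```python
import torch

def beam_search(step_fn, start_id, end_id, k=2, max_len=50):
    # Each hypothesis is (token list, summed log-probability).
    beams = [([start_id], 0.0)]
    for _ in range(max_len - 1):
        candidates = []
        for seq, score in beams:
            if seq[-1] == end_id:
                # Finished hypotheses pass through unchanged.
                candidates.append((seq, score))
                continue
            log_probs = step_fn(torch.tensor(seq))  # [vocab_size]
            top_lp, top_id = torch.topk(log_probs, k)
            for lp, tok in zip(top_lp.tolist(), top_id.tolist()):
                candidates.append((seq + [tok], score + lp))
        # Keep only the k best-scoring hypotheses.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]
        if all(seq[-1] == end_id for seq, _ in beams):
            break
    return beams[0][0]
```

With k=1 this reduces to greedy decoding, which is a handy sanity check when wiring beam search into a model.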
-
This is a beam search implementation. Just replace the forward function of the TransformerDecoder class with this one. If anyone still has submissions available, I'd appreciate it if you could apply this, submit, and share the results.
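A sketch of how the swap might look; the import path and the `beam_forward` name are illustrative, not from this repo, so adjust them to your checkout:

```python
# Illustrative: the actual module path depends on the baseline code layout.
from networks.Attention import TransformerDecoder

# `beam_forward` is the forward function shared in the comment above.
TransformerDecoder.forward = beam_forward

# Inference then takes the beam-search branch:
# out = model.decoder(src, text, is_train=False, batch_max_length=50)
```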