forked from theophilegervet/learner-performance-prediction
model_dkt2.py
import torch
import torch.nn as nn


class DKT2(nn.Module):
def __init__(self, num_items, num_skills, hid_size, embed_size, num_hid_layers, drop_prob):
"""Deep Knowledge Tracing (https://papers.nips.cc/paper/5654-deep-knowledge-tracing.pdf)
with some changes inspired by
Deep Hierarchical Knowledge Tracing (https://arxiv.org/pdf/1908.02146.pdf).
Arguments:
num_items (int): number of items
num_skills (int): number of knowledge points
hid_size (int): hidden layer dimension
embed_size (int): query embedding dimension
num_hid_layers (int): number of hidden layers
drop_prob (float): dropout probability
"""
super(DKT2, self).__init__()
self.embed_size = embed_size
self.item_embeds = nn.Embedding(num_items + 1, embed_size // 2, padding_idx=0)
self.skill_embeds = nn.Embedding(num_skills + 1, embed_size // 2, padding_idx=0)
self.lstm = nn.LSTM(2 * embed_size, hid_size, num_hid_layers, batch_first=True)
self.dropout = nn.Dropout(p=drop_prob)
self.lin1 = nn.Linear(hid_size + embed_size, hid_size)
self.lin2 = nn.Linear(hid_size, 1)
self.hid_size = hid_size
self.num_hid_layers = num_hid_layers

    def forward(self, item_inputs, skill_inputs, label_inputs, item_ids, skill_ids, hidden):
        # Encode past interactions, run them through the LSTM, then score the
        # queried (item, skill) pair with the two-layer output head.
inputs = self.get_inputs(item_inputs, skill_inputs, label_inputs)
query = self.get_query(item_ids, skill_ids)
x, hidden = self.lstm(inputs, hidden)
x = self.lin1(torch.cat([self.dropout(x), query], dim=-1))
x = self.lin2(torch.relu(self.dropout(x))).squeeze(-1)
return x, hidden

    def get_inputs(self, item_inputs, skill_inputs, label_inputs):
        # Concatenate item and skill embeddings twice, then gate the two halves
        # with the previous-answer label: the first copy survives when the
        # answer was correct (label 1), the second when it was incorrect (label 0).
item_inputs = self.item_embeds(item_inputs)
skill_inputs = self.skill_embeds(skill_inputs)
label_inputs = label_inputs.unsqueeze(-1).float()
inputs = torch.cat([item_inputs, skill_inputs, item_inputs, skill_inputs], dim=-1)
inputs[..., :self.embed_size] *= label_inputs
inputs[..., self.embed_size:] *= 1 - label_inputs
return inputs

    def get_query(self, item_ids, skill_ids):
        # Embed the item and skill being predicted and concatenate them.
item_ids = self.item_embeds(item_ids)
skill_ids = self.skill_embeds(skill_ids)
query = torch.cat([item_ids, skill_ids], dim=-1)
return query

    def init_hidden(self, bsz):
        # Zero-initialize the LSTM state (h_0, c_0) with the same dtype and
        # device as the model parameters.
weight = next(self.parameters())
return (weight.new_zeros(self.num_hid_layers, bsz, self.hid_size),
weight.new_zeros(self.num_hid_layers, bsz, self.hid_size))
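

Below is a minimal usage sketch, not part of the original file: it instantiates DKT2 with made-up sizes, builds random interaction tensors, and runs one forward pass. The vocabulary sizes, batch size, and sequence length are illustrative assumptions; index 0 is assumed to be reserved for padding, matching the padding_idx=0 embeddings above.

# Minimal usage sketch (assumed example, not from the original repository).
import torch

# Hypothetical sizes, chosen only for the demo.
num_items, num_skills = 100, 20
batch_size, seq_len = 4, 10

model = DKT2(num_items=num_items, num_skills=num_skills, hid_size=64,
             embed_size=32, num_hid_layers=1, drop_prob=0.25)
model.eval()  # disable dropout for inference

# Previous-step interactions (indices 1..num_items / 1..num_skills; 0 is padding),
# plus the current item/skill ids whose correctness we want to predict.
item_inputs = torch.randint(1, num_items + 1, (batch_size, seq_len))
skill_inputs = torch.randint(1, num_skills + 1, (batch_size, seq_len))
label_inputs = torch.randint(0, 2, (batch_size, seq_len))
item_ids = torch.randint(1, num_items + 1, (batch_size, seq_len))
skill_ids = torch.randint(1, num_skills + 1, (batch_size, seq_len))

hidden = model.init_hidden(batch_size)
with torch.no_grad():
    logits, hidden = model(item_inputs, skill_inputs, label_inputs,
                           item_ids, skill_ids, hidden)
probs = torch.sigmoid(logits)  # shape (batch_size, seq_len): predicted success probabilities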