This repository has been archived by the owner on Nov 23, 2023. It is now read-only.

Commit

model comments, docstrings
victor-shepardson committed Mar 30, 2022
1 parent f6bea11 commit 7e141b5
Showing 1 changed file with 30 additions and 9 deletions.
39 changes: 30 additions & 9 deletions notepredictor/notepredictor/model.py
@@ -81,7 +81,19 @@ def forward(self, x):
# return self.net(x)

class ModalityTransformer(nn.Module):
"""Model joint distribution of modalities autoregressively with random permutations"""
"""
Model joint distribution of note modalities (e.g. pitch, time, velocity).
This is an autoregressive Transformer model for the *internal* structure of notes.
It is *not* autoregressive in time, but in modality.
At training time, it executes in parallel over all timesteps and modalities, with
time dependencies provided via the RNN backbone.
At sampling time it is called serially, one modality at a time,
repeatedly at each time step.
Inspired by XLNet: http://arxiv.org/abs/1906.08237
"""
def __init__(self, input_size, hidden_size, heads=4, layers=1):
super().__init__()
self.net = nn.TransformerDecoder(
@@ -95,13 +107,11 @@ def forward(self, ctx, h_ctx, h_tgt):
ctx: list of Tensor[batch x time x input_size], length note_dim-1
these are the embedded ground truth values
h_ctx: Tensor[batch x time x input_size]
projection of RNN state (need something to attend to when ctx is empty)
h_tgt: list of Tensor[batch x time x input_size], length note_dim
these are projections of the RNN state for each target,
which the Transformer will map to distribution parameters.
"""
# h_tgt = list(h_tgt)
# ctx = list(ctx)

# explicitly broadcast everything to a common batch x time x input_size shape
h_ctx, *ctx = torch.broadcast_tensors(h_ctx, *ctx)
h_ctx, *h_tgt = torch.broadcast_tensors(h_ctx, *h_tgt)
@@ -122,6 +132,7 @@ def forward(self, ctx, h_ctx, h_tgt):

# generate a mask
# this is both the target and memory mask
# masking is such that each target can only depend on "previous" context
n = len(h_tgt)
mask = ~tgt.new_ones((n,n), dtype=bool).tril()
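# e.g. for n=3 the mask is
#   [[False, True,  True ],
#    [False, False, True ],
#    [False, False, False]]
# where True entries are blocked, so target i attends only to positions <= i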

@@ -254,7 +265,7 @@ def embeddings(self):

def forward(self, pitches, times, velocities, validation=False):
"""
teacher-forced probabilistic loss and diagnostics for training.
Args:
pitches: LongTensor[batch, time]
@@ -263,33 +274,41 @@
"""
batch_size, batch_len = pitches.shape

# embed data to input vectors
pitch_emb = self.pitch_emb(pitches) # batch, time, emb_size
time_emb = self.time_emb(times) # batch, time, emb_size
vel_emb = self.vel_emb(velocities) # batch, time, emb_size

embs = (pitch_emb, time_emb, vel_emb)

# feed to RNN backbone
x = torch.cat(embs, -1)[:,:-1] # skip last time position
## broadcast initial state to batch size
initial_state = tuple(
t.expand(self.rnn.num_layers, x.shape[0], -1).contiguous() # num_layers x batch x hidden
for t in self.initial_state)
h, _ = self.rnn(x, initial_state) #batch, time, hidden_size
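# note: x drops the last note, so h[:, t] summarizes notes 0..t and is used to
# predict note t+1 (the targets below are sliced with [:, 1:])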

# fit all note factorizations (e.g. pitch->time->vel vs vel->time->pitch)
# TODO: perm each batch item independently?
# get a random ordering for note modalities:
perm = torch.randperm(self.note_dim)
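# e.g. perm = [2, 0, 1] corresponds to the factorization
# p(velocity) * p(pitch | velocity) * p(time | velocity, pitch)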
# chunk RNN state into Transformer inputs
hs = list(self.h_proj(h).chunk(self.note_dim+1, -1))
h_ctx = hs[0]
h_tgt = [hs[i+1] for i in perm]
# reuse the ground-truth embeddings (shifted to align with the targets) as
# teacher-forcing context, in permuted order; the last modality is never needed as context
embs = [embs[i][:,1:] for i in perm[:-1]]
# run through the Transformer to get conditional hidden states for each modality
mode_hs = self.xformer(embs, h_ctx, h_tgt)
# permute back to canonical order
mode_hs = [mode_hs[i] for i in perm.argsort()]
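# perm.argsort() is the inverse permutation, so mode_hs is back in the
# canonical (pitch, time, velocity) order regardless of perm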

# final projections to raw distribution parameters
pitch_params, time_params, vel_params = [
proj(h) for proj,h in zip(self.projections, mode_hs)]

# get likelihoods of data for each modality
pitch_logits = F.log_softmax(pitch_params, -1)
pitch_targets = pitches[:,1:,None] #batch, time, 1
pitch_log_probs = pitch_logits.gather(-1, pitch_targets)[...,0]
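# gather picks out the log-probability of the observed pitch at each step;
# [..., 0] drops the singleton class dimension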
@@ -309,6 +328,8 @@ def forward(self, pitches, times, velocities, validation=False):
**{'time_'+k:v for k,v in time_result.items()},
**{'velocity_'+k:v for k,v in vel_result.items()}
}
# this just computes some extra diagnostics which are inconvenient to do in the
# training script. should be turned off during training for performance.
if validation:
with torch.no_grad():
r['time_acc_30ms'] = (
