refactor
christophmluscher committed Dec 18, 2024
1 parent e9f9e3a commit 176b2f7
Showing 1 changed file with 27 additions and 28 deletions.
55 changes: 27 additions & 28 deletions i6_models/losses/nce.py
@@ -1,38 +1,37 @@
__all__ = ["NoiseContrastiveEstimationLossV1Config", "NoiseContrastiveEstimationLossV1"]
__all__ = [
"NoiseContrastiveEstimationLossV1",
]

from dataclasses import dataclass
import torch
from torch import nn
from torch.nn import functional as F
from typing import Optional
import math

from i6_models.config import ModelConfiguration


@dataclass
class NoiseContrastiveEstimationLossV1Config(ModelConfiguration):
num_samples: int
model: nn.Module
noise_distribution_sampler: nn.Module # torch.utils.data.Sampler
log_norm_term: Optional[float] = None


class NoiseContrastiveEstimationLossV1(nn.Module): # (nn.modules.loss._Loss):
class NoiseContrastiveEstimationLossV1(nn.Module):
__constants__ = ["num_samples", "log_norm_term", "reduction"]
num_samples: int
log_norm_term: float

def __init__(self, cfg: NoiseContrastiveEstimationLossV1Config) -> None:
def __init__(
self,
num_samples: int,
model: nn.Module,
noise_distribution_sampler: nn.Module,
log_norm_term: Optional[float] = None,
reduction: str = "none",
) -> None:
super().__init__()
self.num_samples = cfg.num_samples
self.model = cfg.model # only used to access weights of output layer for NCE computation
self.noise_distribution_sampler = cfg.noise_distribution_sampler
self.log_norm_term = cfg.log_norm_term

self._bce = nn.BCEWithLogitsLoss(reduction="none")
self.num_samples = num_samples
self.model = model # only used to access weights of output layer for NCE computation
self.noise_distribution_sampler = noise_distribution_sampler
self.log_norm_term = log_norm_term

self._bce = nn.BCEWithLogitsLoss(reduction=reduction)

def forward(self, data_tensor: torch.Tensor, target: torch.Tensor):
def forward(self, data: torch.Tensor, target: torch.Tensor):
# input: [B x T, F] target: [B x T]

with torch.no_grad():
@@ -57,19 +56,19 @@ def forward(self, data_tensor: torch.Tensor, target: torch.Tensor):
all_b = F.embedding(all_classes, torch.unsqueeze(self.model.output.bias, 1)) # [B X T + num_sampled, 1]

# slice embeddings for targets and samples below
true_emb = torch.narrow(all_emb, 0, 0, data_tensor.shape[0]) # [B x T, F]
true_b = torch.narrow(all_b, 0, 0, data_tensor.shape[0]) # [B x T, 1]
true_emb = torch.narrow(all_emb, 0, 0, data.shape[0]) # [B x T, F]
true_b = torch.narrow(all_b, 0, 0, data.shape[0]) # [B x T, 1]

sampled_emb = torch.narrow(all_emb, 0, data_tensor.shape[0], self.num_samples) # [num_sampled, F]
sampled_b = torch.narrow(all_b, 0, data_tensor.shape[0], self.num_samples).squeeze(
sampled_emb = torch.narrow(all_emb, 0, data.shape[0], self.num_samples) # [num_sampled, F]
sampled_b = torch.narrow(all_b, 0, data.shape[0], self.num_samples).squeeze(
1
) # [num_sampled], remove dim for broadcasting

# compute logits log p(w|h)
sampled_logits = torch.matmul(data_tensor, sampled_emb.T) # [B x T, num_sampled]
sampled_logits = torch.matmul(data, sampled_emb.T) # [B x T, num_sampled]

# row-wise dot product
true_logits = torch.multiply(data_tensor, true_emb)
true_logits = torch.multiply(data, true_emb) # [B x T, F]
true_logits = torch.sum(true_logits, 1, keepdim=True) # [B x T, 1]

true_logits += true_b
@@ -80,8 +79,8 @@ def forward(self, data_tensor: torch.Tensor, target: torch.Tensor):
true_logits -= self.log_norm_term
sampled_logits -= self.log_norm_term

true_logits -= torch.log(true_sample_prob.unsqueeze(1))
sampled_logits -= torch.log(sampled_prob.unsqueeze(0))
true_logits -= true_sample_prob.unsqueeze(1)
sampled_logits -= sampled_prob.unsqueeze(0)

out_logits = torch.cat((true_logits, sampled_logits), 1) # [B x T, 1 + num_sampled]
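
For context (not part of the commit): the visible lines assemble NCE-style logits, per position, for the true class and each sampled noise class,

    s(w \mid h) = h^\top e_w + b_w - \log Z - \log q(w),

where e_w and b_w are the output-layer weight and bias for class w, log Z is the fixed log_norm_term, and q is the noise distribution. Judging from the dropped torch.log calls, the sampler is now expected to return log-probabilities directly. The concatenated out_logits are presumably scored against binary targets by the BCEWithLogitsLoss created in __init__; that call sits in the part of the diff collapsed here.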

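A minimal construction sketch of the refactored interface (not part of the commit; ToyLM and UniformNoiseSampler are hypothetical stand-ins, and the sampler interface used inside forward is not visible in this diff):

import torch
from torch import nn

from i6_models.losses.nce import NoiseContrastiveEstimationLossV1


class ToyLM(nn.Module):
    """Minimal model exposing the output projection whose weights the loss reuses."""

    def __init__(self, hidden_dim: int, vocab_size: int) -> None:
        super().__init__()
        self.output = nn.Linear(hidden_dim, vocab_size)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        return self.output(hidden)


class UniformNoiseSampler(nn.Module):
    """Placeholder noise-distribution sampler (assumed interface, for illustration only)."""

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size


vocab_size, hidden_dim = 10_000, 512

nce = NoiseContrastiveEstimationLossV1(
    num_samples=1024,                          # noise samples drawn per batch (illustrative value)
    model=ToyLM(hidden_dim, vocab_size),       # only model.output weights/bias are read by the loss
    noise_distribution_sampler=UniformNoiseSampler(vocab_size),
    log_norm_term=9.0,                         # fixed log-normalization term; illustrative value
    reduction="none",                          # forwarded to the internal BCEWithLogitsLoss
)

Calling nce(data, target) additionally requires the sampler to provide sampled class indices and their (log-)probabilities as consumed in forward, so this sketch stops at construction.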