From 176b2f7e6354e9ebaa721b21496778d19820e812 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Christoph=20M=2E=20L=C3=BCscher?=
Date: Wed, 18 Dec 2024 18:15:36 +0100
Subject: [PATCH] refactor

---
 i6_models/losses/nce.py | 55 ++++++++++++++++++++---------------------
 1 file changed, 27 insertions(+), 28 deletions(-)

diff --git a/i6_models/losses/nce.py b/i6_models/losses/nce.py
index 4522c27a..d55e0d58 100644
--- a/i6_models/losses/nce.py
+++ b/i6_models/losses/nce.py
@@ -1,38 +1,37 @@
-__all__ = ["NoiseContrastiveEstimationLossV1Config", "NoiseContrastiveEstimationLossV1"]
+__all__ = [
+    "NoiseContrastiveEstimationLossV1",
+]
 
-from dataclasses import dataclass
 import torch
 from torch import nn
 from torch.nn import functional as F
 from typing import Optional
 import math
 
-from i6_models.config import ModelConfiguration
-
-@dataclass
-class NoiseContrastiveEstimationLossV1Config(ModelConfiguration):
-    num_samples: int
-    model: nn.Module
-    noise_distribution_sampler: nn.Module  # torch.utils.data.Sampler
-    log_norm_term: Optional[float] = None
-
-
-class NoiseContrastiveEstimationLossV1(nn.Module):  # (nn.modules.loss._Loss):
+
+class NoiseContrastiveEstimationLossV1(nn.Module):
     __constants__ = ["num_samples", "log_norm_term", "reduction"]
     num_samples: int
    log_norm_term: float
 
-    def __init__(self, cfg: NoiseContrastiveEstimationLossV1Config) -> None:
+    def __init__(
+        self,
+        num_samples: int,
+        model: nn.Module,
+        noise_distribution_sampler: nn.Module,
+        log_norm_term: Optional[float] = None,
+        reduction: str = "none",
+    ) -> None:
         super().__init__()
-        self.num_samples = cfg.num_samples
-        self.model = cfg.model  # only used to access weights of output layer for NCE computation
-        self.noise_distribution_sampler = cfg.noise_distribution_sampler
-        self.log_norm_term = cfg.log_norm_term
-        self._bce = nn.BCEWithLogitsLoss(reduction="none")
+        self.num_samples = num_samples
+        self.model = model  # only used to access weights of output layer for NCE computation
+        self.noise_distribution_sampler = noise_distribution_sampler
+        self.log_norm_term = log_norm_term
+
+        self._bce = nn.BCEWithLogitsLoss(reduction=reduction)
 
-    def forward(self, data_tensor: torch.Tensor, target: torch.Tensor):
+    def forward(self, data: torch.Tensor, target: torch.Tensor):
         # input: [B x T, F] target: [B x T]
 
         with torch.no_grad():
@@ -57,19 +56,19 @@ def forward(self, data_tensor: torch.Tensor, target: torch.Tensor):
             all_b = F.embedding(all_classes, torch.unsqueeze(self.model.output.bias, 1))  # [B X T + num_sampled, 1]
 
         # slice embeddings for targets and samples below
-        true_emb = torch.narrow(all_emb, 0, 0, data_tensor.shape[0])  # [B x T, F]
-        true_b = torch.narrow(all_b, 0, 0, data_tensor.shape[0])  # [B x T, 1]
+        true_emb = torch.narrow(all_emb, 0, 0, data.shape[0])  # [B x T, F]
+        true_b = torch.narrow(all_b, 0, 0, data.shape[0])  # [B x T, 1]
 
-        sampled_emb = torch.narrow(all_emb, 0, data_tensor.shape[0], self.num_samples)  # [num_sampled, F]
-        sampled_b = torch.narrow(all_b, 0, data_tensor.shape[0], self.num_samples).squeeze(
+        sampled_emb = torch.narrow(all_emb, 0, data.shape[0], self.num_samples)  # [num_sampled, F]
+        sampled_b = torch.narrow(all_b, 0, data.shape[0], self.num_samples).squeeze(
             1
         )  # [num_sampled], remove dim for broadcasting
 
         # compute logits log p(w|h)
-        sampled_logits = torch.matmul(data_tensor, sampled_emb.T)  # [B x T, num_sampled]
+        sampled_logits = torch.matmul(data, sampled_emb.T)  # [B x T, num_sampled]
 
         # row-wise dot product
-        true_logits = torch.multiply(data_tensor, true_emb)
+        true_logits = torch.multiply(data, true_emb)  # [B x T, F]
         true_logits = torch.sum(true_logits, 1, keepdim=True)  # [B x T, 1]
         true_logits += true_b
 
@@ -80,8 +79,8 @@ def forward(self, data_tensor: torch.Tensor, target: torch.Tensor):
             true_logits -= self.log_norm_term
             sampled_logits -= self.log_norm_term
 
-        true_logits -= torch.log(true_sample_prob.unsqueeze(1))
-        sampled_logits -= torch.log(sampled_prob.unsqueeze(0))
+        true_logits -= true_sample_prob.unsqueeze(1)
+        sampled_logits -= sampled_prob.unsqueeze(0)
 
         out_logits = torch.cat((true_logits, sampled_logits), 1)  # [B x T, 1 + num_sampled]
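
Note (not part of the patch): below is a minimal usage sketch of the refactored constructor, which now takes its parameters directly instead of a NoiseContrastiveEstimationLossV1Config. The toy LM and the sampler placeholder are hypothetical stand-ins; the patch does not show the sampler protocol used inside forward(), and the final hunk subtracts the sampler outputs without torch.log(), which suggests the sampler is expected to return log-probabilities.

    import torch
    from torch import nn

    from i6_models.losses.nce import NoiseContrastiveEstimationLossV1


    class ToyLM(nn.Module):
        """Stub LM: the loss only reads model.output.weight / model.output.bias."""

        def __init__(self, vocab_size: int = 1000, hidden_dim: int = 128):
            super().__init__()
            self.output = nn.Linear(hidden_dim, vocab_size)  # final output layer


    lm = ToyLM()

    # Placeholder; a real run needs a sampler implementing the (unshown)
    # interface that forward() calls, presumably returning log-probabilities.
    sampler = nn.Module()

    nce_loss = NoiseContrastiveEstimationLossV1(
        num_samples=100,      # noise samples, shared across all B x T positions
        model=lm,             # output-layer weights/biases are read for the NCE logits
        noise_distribution_sampler=sampler,
        log_norm_term=9.0,    # assumed fixed log normalization constant; optional
        reduction="none",     # forwarded to nn.BCEWithLogitsLoss
    )

    # Shapes per the comments in forward():
    #   data:   [B x T, F] flattened hidden states
    #   target: [B x T]    class indices
    # loss = nce_loss(data, target)  # requires a real sampler to execute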