From cac3cf3a68e271a4fb1189c9fbdc6f1f4d51ab3b Mon Sep 17 00:00:00 2001
From: hedi7 <heini89@gmail.com>
Date: Thu, 5 Sep 2019 21:40:23 +0800
Subject: [PATCH] Updated to use the ignite framework, which simplifies much
 of the stats calculation; enhanced the loss with a weighted average to
 improve training stability. Removed tanh from attention, which further
 improves stability.

---
 README.md                 |  46 +--
 dataset.py                |  80 ++--
 losses.py                 |  85 ++--
 models.py                 | 109 +++--
 requirements.txt          |  35 +-
 run.py                    | 851 +++++++++++++++++---------------------
 show_most_common_words.py |  78 ++--
 7 files changed, 632 insertions(+), 652 deletions(-)

diff --git a/README.md b/README.md
index d546c47..f14954b 100644
--- a/README.md
+++ b/README.md
@@ -8,24 +8,27 @@ Source code for the paper [Text-based Depression Detection: What Triggers An Ale
 The required python packages can be found in `requirements.txt`.
 
 ```
-tqdm==4.24.0
-allennlp==0.8.2
-tableprint==0.8.0
-tabulate==0.8.2
-fire==0.1.3
-nltk==3.3
+torch==1.2.0
 kaldi_io==0.9.1
-scipy==1.2.1
-torchnet==0.0.4
-pandas==0.24.1
-numpy==1.16.2
-bert_serving_client==1.8.3
-imbalanced_learn==0.4.3
-torch==0.4.1.post2
-gensim==3.7.1
+bert_serving_server==1.9.6
+pytorch_ignite==0.2.0
+numpy==1.16.4
+librosa==0.7.0
+tabulate==0.8.3
+mistletoe==0.7.2
+scipy==1.3.0
+tqdm==4.32.2
+pandas==0.24.2
+fire==0.1.3
+imbalanced_learn==0.5.0
+allennlp==0.8.5
+gensim==3.8.0
+ignite==1.1.0
 imblearn==0.0
-scikit_learn==0.20.3
-PyYAML==5.1
+nltk==3.4.5
+plotnine==0.6.0
+scikit_learn==0.21.3
+PyYAML==5.1.2
 ```
 
 
@@ -70,14 +73,9 @@ The main script of this repo is `run.py`.
 The code is centered around the config files placed at `config/`. Each parameter in these files can be modified for each run using google-fire e.g., if one opts to run a different model, just pass `--model GRU`. 
 
 `run.py` the following options ( ran as `python run.py OPTION`):
-* `train`: Trains a model given a config file (default is `config/text_lstm_deep.yaml`)
-* `stats`: Prints the evaluation results on the development set
-* `trainstats`: Convenience function to run train and evaluate in one.
-* `search`: Parameter search for learning rate, momentum and nesterov (SGD)
-* `searchadam`: Learning rate search for adam optimizer
-* `ex`: Extracts features from a given network (not finished),
-* `fwd`: Debugging function to forward some features through a network
-* `fuse`: Fusion of two models or more. Fusion is done by averaging each output.
+* `train`: Trains a model given a config file (default is `config/text_lstm_deep.yaml`).
+* `evaluate`: Evaluates a given trained model directory. Just pass the result of `train` to it.
+* `evaluates`: Same as evaluate but runs multiple passed directories e.g., passed as glob (`experiment/*/*/*`), and returns an outputfile as well as a table report of the results. Useful for multiple runs with different seeds.
 
 ## Notes
 
diff --git a/dataset.py b/dataset.py
index 18ed255..5a0a605 100644
--- a/dataset.py
+++ b/dataset.py
@@ -18,7 +18,6 @@ class ListDataset(torch.utils.data.Dataset):
     Arguments:
         *lists (List): List that have the same size of the first dimension.
     """
-
     def __init__(self, *lists):
         assert all(len(lists[0]) == len(a_list) for a_list in lists)
         self.lists = lists
@@ -30,38 +29,32 @@ def __len__(self):
         return len(self.lists[0])
 
 
-def seq_collate_fn(data_batches):
-    """seq_collate_fn
-
-    Helper function for torch.utils.data.Dataloader
-
-    :param data_batches: iterateable
-    """
-    data_batches.sort(key=lambda x: len(x[0]), reverse=True)
-
-    def merge_seq(dataseq, dim=0):
-        lengths = [seq.shape for seq in dataseq]
-        # Assuming duration is given in the first dimension of each sequence
-        maxlengths = tuple(np.max(lengths, axis=dim))
-
-        # For the case that the lenthts are 2dimensional
-        lengths = np.array(lengths)[:, dim]
-        # batch_mean = np.mean(np.concatenate(dataseq),axis=0, keepdims=True)
-        # padded = np.tile(batch_mean, (len(dataseq), maxlengths[0], 1))
-        padded = np.zeros((len(dataseq),) + maxlengths)
-        for i, seq in enumerate(dataseq):
-            end = lengths[i]
-            padded[i, :end] = seq[:end]
-        return padded, lengths
-    features, targets = zip(*data_batches)
-    features_seq, feature_lengths = merge_seq(features)
-    return torch.from_numpy(features_seq), torch.tensor(targets)
-
-
-def create_dataloader(
-        kaldi_string, label_dict, transform=None,
-        batch_size: int = 16, num_workers: int = 1, shuffle: bool = True
-):
+def pad(tensorlist, batch_first=True, padding_value=0.):
+    # In case we have 3d tensor in each element, squeeze the first dim (usually 1)
+    if len(tensorlist[0].shape) == 3:
+        tensorlist = [ten.squeeze() for ten in tensorlist]
+    # In case of len == 1 padding will throw an error
+    if len(tensorlist) == 1:
+        return torch.as_tensor(tensorlist)
+    tensorlist = [torch.as_tensor(item) for item in tensorlist]
+    return torch.nn.utils.rnn.pad_sequence(tensorlist,
+                                           batch_first=batch_first,
+                                           padding_value=padding_value)
+
+
+def sequential_collate(batches):
+    # sort length wise
+    batches.sort(key=lambda x: len(x), reverse=True)
+    features, targets = zip(*batches)
+    return pad(features), torch.as_tensor(targets)
+
+
+def create_dataloader(kaldi_string,
+                      label_dict,
+                      transform=None,
+                      batch_size: int = 16,
+                      num_workers: int = 1,
+                      shuffle: bool = True):
     """create_dataloader
 
     :param kaldi_string: copy-feats input
@@ -83,7 +76,8 @@ def valid_feat(item):
     features = []
     labels = []
     # Directly filter out all utterances without labels
-    for idx, (k, feat) in enumerate(filter(valid_feat, kaldi_io.read_mat_ark(kaldi_string))):
+    for idx, (k, feat) in enumerate(
+            filter(valid_feat, kaldi_io.read_mat_ark(kaldi_string))):
         if transform:
             feat = transform(feat)
         features.append(feat)
@@ -91,17 +85,21 @@ def valid_feat(item):
     assert len(features) > 0, "No features were found, are the labels correct?"
     # Shuffling means that this is training dataset, so oversample
     if shuffle:
-        random_oversampler = RandomOverSampler(random_state=0)
+        sampler = RandomOverSampler()
         # Assume that label is Score, Binary, we take the binary to oversample
         sample_index = 1 if len(labels[0]) == 2 else 0
         # Dummy X data, y is the binary label
-        _, _ = random_oversampler.fit_resample(
-            torch.ones(len(features), 1), [l[sample_index] for l in labels])
+        _, _ = sampler.fit_resample(torch.ones(len(features), 1),
+                                    [l[sample_index] for l in labels])
         # Get the indices for the sampled data
-        indicies = random_oversampler.sample_indices_
+        indicies = sampler.sample_indices_
         # reindex, new data is oversampled in the minority class
-        features, labels = [features[id]
-                            for id in indicies], [labels[id] for id in indicies]
+        features, labels = [features[id] for id in indicies
+                            ], [labels[id] for id in indicies]
     dataset = ListDataset(features, labels)
     # Configure weights to reduce number of unseen utterances
-    return data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=seq_collate_fn, shuffle=shuffle)
+    return data.DataLoader(dataset,
+                           batch_size=batch_size,
+                           num_workers=num_workers,
+                           collate_fn=sequential_collate,
+                           shuffle=shuffle)
diff --git a/losses.py b/losses.py
index c54f3a7..1db2086 100644
--- a/losses.py
+++ b/losses.py
@@ -4,9 +4,7 @@
 
 
 class MAELoss(torch.nn.Module):
-
     """Docstring for MAELoss. """
-
     def __init__(self):
         """TODO: to be defined1. """
         torch.nn.Module.__init__(self)
@@ -19,9 +17,7 @@ def forward(self, input, target):
 
 
 class RMSELoss(torch.nn.Module):
-
     """Docstring for RMSELoss. """
-
     def __init__(self):
         """ """
         torch.nn.Module.__init__(self)
@@ -33,10 +29,35 @@ def forward(self, input, target):
         return torch.sqrt(self.loss(input, target))
 
 
-class MSELoss(torch.nn.Module):
+class CosineLoss(torch.nn.Module):
+    """description"""
+    def __init__(self):
+        torch.nn.Module.__init__(self)
+        self.norm = torch.nn.functional.normalize
 
-    """Docstring for MSELoss. """
+    @staticmethod
+    def label_to_onehot(tar, nlabels=2):
+        if tar.ndimension() == 1:
+            tar = tar.unsqueeze(-1)  # add singleton [B, 1]
+        tar_onehot = tar.new_zeros((len(tar), nlabels)).detach()
+        tar_onehot.scatter_(1, tar.long(), 1)
+        return tar_onehot.float()
 
+    def forward(self, input, target):
+        target = CosineLoss.label_to_onehot(target)
+        if input.ndimension() == 2:
+            input = input.unsqueeze(-1)  # add singleton dimension
+        if target.ndimension() == 2:
+            target = target.unsqueeze(1)  # Add singleton dimension
+        norm_input = self.norm(input, p=2, dim=1)
+        #Input shape: [Bx1xC]
+        #Target shape: [BxCx1]
+        cos_loss = 1 - torch.bmm(target, norm_input)
+        return cos_loss.mean()
+
+
+class MSELoss(torch.nn.Module):
+    """Docstring for MSELoss. """
     def __init__(self):
         """ """
         torch.nn.Module.__init__(self)
@@ -49,24 +70,18 @@ def forward(self, input, target):
 
 
 class HuberLoss(torch.nn.Module):
-
     """Docstring for HuberLoss. """
-
     def __init__(self):
         """ """
         torch.nn.Module.__init__(self)
-
         self.loss = torch.nn.SmoothL1Loss()
 
     def forward(self, input, target):
-        target = target.float()
-        return self.loss(input, target)
+        return self.loss(input, target.float())
 
 
 class DepressionLoss(torch.nn.Module):
-
     """Docstring for DepressionLoss. """
-
     def __init__(self):
         """ """
         torch.nn.Module.__init__(self)
@@ -75,45 +90,57 @@ def __init__(self):
         self.bin_loss = BCEWithLogitsLoss()
 
     def forward(self, input, target):
-        return self.score_loss(input[:, 0], target[:, 0]) + self.bin_loss(input[:, 1], target[:, 1])
-
+        return self.score_loss(input[:, 0], target[:, 0]) + self.bin_loss(
+            input[:, 1], target[:, 1])
 
-class DepressionLossMSE(torch.nn.Module):
 
+class DepressionLossSmoothCos(torch.nn.Module):
     """Docstring for DepressionLoss. """
-
     def __init__(self):
         """ """
         torch.nn.Module.__init__(self)
 
-        self.score_loss = MSELoss()
-        self.bin_loss = BCEWithLogitsLoss()
+        self.score_loss = HuberLoss()
+        self.cos_loss = CosineLoss()
 
     def forward(self, input, target):
-        score_loss = self.score_loss(input[:, 0], target[:, 0])
-        binary_loss = self.bin_loss(input[:, 1], target[:, 1])
-        return score_loss + binary_loss
+        target = target.long()
+        phq8_pred, phq8_tar = input[:, 0], target[:, 0]
+        binary_pred, binary_tar = input[:, 1:3], target[:, 1]
+        return self.score_loss(phq8_pred, phq8_tar) + self.cos_loss(
+            binary_pred, binary_tar)
 
 
 class DepressionLossSmooth(torch.nn.Module):
-
     """Docstring for DepressionLoss. """
-
-    def __init__(self):
+    def __init__(self, reduction='sum'):
         """ """
         torch.nn.Module.__init__(self)
 
         self.score_loss = HuberLoss()
-        self.bin_loss = BCEWithLogitsLoss()
+        self.bce = BCEWithLogitsLoss()
+        self.weight = torch.nn.Parameter(torch.tensor(0.))
+        self.reduction = reduction
+        self.eps = 0.01
 
     def forward(self, input, target):
-        return self.score_loss(input[:, 0], target[:, 0]) + self.bin_loss(input[:, 1], target[:, 1])
+        phq8_pred, phq8_tar = input[:, 0], target[:, 0]
+        binary_pred, binary_tar = input[:, 1], target[:, 1]
+        score_loss, bin_loss = self.score_loss(phq8_pred, phq8_tar), self.bce(
+            binary_pred, binary_tar)
+        weight = torch.clamp(torch.sigmoid(self.weight),
+                             min=self.eps,
+                             max=1 - self.eps)
+        stacked_loss = (weight * score_loss) + ((1 - weight) * bin_loss)
+        if self.reduction == 'mean':
+            stacked_loss = stacked_loss.mean()
+        elif self.reduction == 'sum':
+            stacked_loss = stacked_loss.sum()
+        return stacked_loss
 
 
 class BCEWithLogitsLoss(torch.nn.Module):
-
     """Docstring for BCEWithLogitsLoss. """
-
     def __init__(self):
         """TODO: to be defined1. """
         torch.nn.Module.__init__(self)
diff --git a/models.py b/models.py
index 1b51524..e4063f6 100644
--- a/models.py
+++ b/models.py
@@ -1,8 +1,15 @@
 import torch
+import math
 import torch.nn as nn
+from TCN.tcn import TemporalConvNet
 
 
 def init_rnn(rnn):
+    """init_rnn
+    Initialized RNN weights, independent of type GRU/LSTM/RNN
+
+    :param rnn: the rnn model 
+    """
     for name, param in rnn.named_parameters():
         if 'bias' in name:
             nn.init.constant_(param, 0.0)
@@ -10,10 +17,32 @@ def init_rnn(rnn):
             nn.init.xavier_uniform_(param)
 
 
-class LSTM(torch.nn.Module):
+class AutoEncoderLSTM(nn.Module):
+    """docstring for AutoEncoderLSTM"""
+    def __init__(self, inputdim, output_size=None, **kwargs):
+        super(AutoEncoderLSTM, self).__init__()
+        kwargs.setdefault('hidden_size', 128)
+        kwargs.setdefault('num_layers', 3)
+        kwargs.setdefault('bidirectional', True)
+        kwargs.setdefault('dropout', 0.2)
+        self.net = nn.LSTM(inputdim, batch_first=True, **kwargs)
+        self.decoder = nn.LSTM(input_size=kwargs['hidden_size'] *
+                               (int(kwargs['bidirectional']) + 1),
+                               hidden_size=inputdim,
+                               batch_first=True,
+                               bidirectional=kwargs['bidirectional'])
+        self.squeezer = nn.Sequential()
+        if kwargs['bidirectional']:
+            self.squeezer = nn.Linear(inputdim * 2, inputdim)
+
+    def forward(self, x):
+        enc_o, _ = self.net(x)
+        out, _ = self.decoder(enc_o)
+        return self.squeezer(out)
 
-    """LSTM class for Depression detection"""
 
+class LSTM(torch.nn.Module):
+    """LSTM class for Depression detection"""
     def __init__(self, inputdim: int, output_size: int, **kwargs):
         """
 
@@ -27,17 +56,16 @@ def __init__(self, inputdim: int, output_size: int, **kwargs):
         kwargs.setdefault('hidden_size', 128)
         kwargs.setdefault('num_layers', 2)
         kwargs.setdefault('bidirectional', True)
-        kwargs.setdefault('dropout', 0.2)
+        kwargs.setdefault('dropout', 0.1)
         self.net = nn.LSTM(inputdim, batch_first=True, **kwargs)
         init_rnn(self.net)
         rnn_outputdim = self.net(torch.randn(1, 50, inputdim))[0].shape
-        self.outputlayer = nn.Linear(
-            rnn_outputdim[-1], output_size)
+        self.outputlayer = nn.Linear(rnn_outputdim[-1], output_size)
 
-    def forward(self, x):
+    def forward(self, x: torch.tensor):
         """Forwards input vector through network
 
-        :x: TODO
+        :x: torch.tensor
         :returns: TODO
 
         """
@@ -46,9 +74,7 @@ def forward(self, x):
 
 
 class GRU(torch.nn.Module):
-
     """GRU class for Depression detection"""
-
     def __init__(self, inputdim: int, output_size: int, **kwargs):
         """
 
@@ -76,13 +102,11 @@ def forward(self, x):
 
         """
         x, _ = self.net(x)
-        return self.outputlayer(x)
+        return self.outputlayer(x)  # Pool time
 
 
 class GRUAttn(torch.nn.Module):
-
     """GRUAttn class for Depression detection"""
-
     def __init__(self, inputdim: int, output_size: int, **kwargs):
         """
 
@@ -101,7 +125,7 @@ def __init__(self, inputdim: int, output_size: int, **kwargs):
         init_rnn(self.net)
         rnn_outputdim = self.net(torch.randn(1, 50, inputdim))[0].shape
         self.outputlayer = nn.Linear(rnn_outputdim[-1], output_size)
-        self.attn = SimpleAttention(kwargs['hidden_size']*2)
+        self.attn = SimpleAttention(kwargs['hidden_size'] * 2)
 
     def forward(self, x):
         """Forwards input vector through network
@@ -115,10 +139,47 @@ def forward(self, x):
         return self.outputlayer(x)
 
 
-class LSTMAttn(torch.nn.Module):
-
+class LSTMDualAttn(torch.nn.Module):
     """LSTMSimpleAttn class for Depression detection"""
+    def __init__(self, inputdim: int, output_size: int, **kwargs):
+        """
 
+        :inputdim:int: Input dimension
+        :output_size:int: Output dimension of LSTMSimpleAttn
+        :**kwargs: Other args, passed down to nn.LSTMSimpleAttn
+
+
+        """
+        torch.nn.Module.__init__(self)
+        kwargs.setdefault('hidden_size', 128)
+        kwargs.setdefault('num_layers', 3)
+        kwargs.setdefault('bidirectional', True)
+        kwargs.setdefault('dropout', 0.2)
+        self.lstm = nn.LSTM(inputdim, **kwargs)
+        init_rnn(self.lstm)
+        self.outputlayer = nn.Linear(kwargs['hidden_size'] * 2, output_size)
+        nn.init.kaiming_normal_(self.outputlayer.weight)
+        self.attn = nn.Linear(kwargs['hidden_size'] * 2, output_size)
+        nn.init.kaiming_normal_(self.outputlayer.weight)
+        # self.mae_attn = SimpleAttention(kwargs['hidden_size'] * 2, 1)
+        # self.bin_attn = SimpleAttention(kwargs['hidden_size'] * 2, 1)
+
+    def forward(self, x):
+        """Forwards input vector through network
+
+        :x: input tensor of shape (B, T, D) [Batch, Time, Dim]
+        :returns: TODO
+
+        """
+        x, _ = self.lstm(x)
+        out = self.outputlayer(x)
+        time_attn = torch.softmax(self.attn(x), dim=1)
+        pooled = (time_attn * out).sum(dim=1).unsqueeze(1)
+        return pooled
+
+
+class LSTMAttn(torch.nn.Module):
+    """LSTMSimpleAttn class for Depression detection"""
     def __init__(self, inputdim: int, output_size: int, **kwargs):
         """
 
@@ -130,22 +191,22 @@ def __init__(self, inputdim: int, output_size: int, **kwargs):
         """
         torch.nn.Module.__init__(self)
         kwargs.setdefault('hidden_size', 128)
-        kwargs.setdefault('num_layers', 4)
+        kwargs.setdefault('num_layers', 3)
         kwargs.setdefault('bidirectional', True)
         kwargs.setdefault('dropout', 0.2)
         self.lstm = LSTM(inputdim, output_size, **kwargs)
         init_rnn(self.lstm)
-        self.attn = SimpleAttention(kwargs['hidden_size']*2)
+        self.attn = SimpleAttention(kwargs['hidden_size'] * 2)
 
     def forward(self, x):
         """Forwards input vector through network
 
-        :x: TODO
+        :x: input tensor of shape (B, T, D) [Batch, Time, Dim]
         :returns: TODO
 
         """
         x, _ = self.lstm.net(x)
-        x = self.attn(x)[0].unsqueeze(1)
+        x = self.attn(x)[0]
         return self.lstm.outputlayer(x)
 
     def extract_feature(self, x):
@@ -154,10 +215,8 @@ def extract_feature(self, x):
 
 
 class SimpleAttention(nn.Module):
-
     """Docstring for SimpleAttention. """
-
-    def __init__(self, inputdim):
+    def __init__(self, inputdim, outputdim=1):
         """TODO: to be defined1.
 
         :inputdim: TODO
@@ -166,10 +225,10 @@ def __init__(self, inputdim):
         nn.Module.__init__(self)
 
         self._inputdim = inputdim
-        self.attn = nn.Linear(inputdim, 1, bias=False)
-        nn.init.xavier_uniform_(self.attn.weight)
+        self.attn = nn.Linear(inputdim, outputdim, bias=False)
+        nn.init.normal_(self.attn.weight, std=0.05)
 
     def forward(self, x):
         weights = torch.softmax(self.attn(x), dim=1)
-        out = torch.bmm(weights.transpose(1, 2), torch.tanh(x)).squeeze(0)
+        out = (weights * x).sum(dim=1).unsqueeze(1)
         return out, weights
diff --git a/requirements.txt b/requirements.txt
index 3166914..3ec294c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,18 +1,21 @@
-tqdm==4.24.0
-allennlp==0.8.2
-tableprint==0.8.0
-tabulate==0.8.2
-fire==0.1.3
-nltk==3.3
+torch==1.2.0
 kaldi_io==0.9.1
-scipy==1.2.1
-torchnet==0.0.4
-pandas==0.24.1
-numpy==1.16.2
-bert_serving_client==1.8.3
-imbalanced_learn==0.4.3
-torch==0.4.1.post2
-gensim==3.7.1
+bert_serving_server==1.9.6
+pytorch_ignite==0.2.0
+numpy==1.16.4
+librosa==0.7.0
+tabulate==0.8.3
+mistletoe==0.7.2
+scipy==1.3.0
+tqdm==4.32.2
+pandas==0.24.2
+fire==0.1.3
+imbalanced_learn==0.5.0
+allennlp==0.8.5
+gensim==3.8.0
+ignite==1.1.0
 imblearn==0.0
-scikit_learn==0.20.3
-PyYAML==5.1
+nltk==3.4.5
+plotnine==0.6.0
+scikit_learn==0.21.3
+PyYAML==5.1.2
diff --git a/run.py b/run.py
index e3448c1..11a3732 100644
--- a/run.py
+++ b/run.py
@@ -3,6 +3,7 @@
 import datetime
 import torch
 from pprint import pformat
+import glob
 import models
 from dataset import create_dataloader
 import fire
@@ -14,518 +15,402 @@
 import os
 import numpy as np
 from sklearn import metrics
-import tableprint as tp
 import sklearn.preprocessing as pre
-import torchnet as tnt
-
-
-class BinarySimilarMeter(object):
-    """Only counts ones, does not consider zeros as being correct"""
-
-    def __init__(self, sigmoid_output=False):
-        super(BinarySimilarMeter, self).__init__()
-        self.sigmoid_output = sigmoid_output
-        self.reset()
-
-    def reset(self):
-        self.correct = 0
-        self.n = 0
-
-    def add(self, output, target):
-        if self.sigmoid_output:
-            output = torch.sigmoid(output)
-        target = target.float()
-        output = output.round()
-        self.correct += np.sum(np.logical_and(output, target).numpy())
-        self.n += (target == 1).nonzero().shape[0]
-
-    def value(self):
-        if self.n == 0:
-            return 0
-        return (self.correct / self.n) * 100.
-
-
-class BinaryAccuracyMeter(object):
-    """Counts all outputs, including zero"""
-
-    def __init__(self, sigmoid_output=False):
-        super(BinaryAccuracyMeter, self).__init__()
-        self.sigmoid_output = sigmoid_output
-        self.reset()
-
-    def reset(self):
-        self.correct = 0
-        self.n = 0
+import uuid
+from tabulate import tabulate
+import sys
+from ignite.contrib.handlers import ProgressBar
+from ignite.engine import (Engine, Events)
+from ignite.handlers import EarlyStopping, ModelCheckpoint
+from ignite.metrics import Loss, RunningAverage, ConfusionMatrix, MeanAbsoluteError, Precision, Recall
+from ignite.contrib.handlers.param_scheduler import LRScheduler
+from torch.optim.lr_scheduler import StepLR
+
+device = 'cpu'
+if torch.cuda.is_available(
+) and 'SLURM_JOB_PARTITION' in os.environ and 'gpu' in os.environ[
+        'SLURM_JOB_PARTITION']:
+    device = 'cuda'
+    # Without results are slightly inconsistent
+    torch.backends.cudnn.deterministic = True
+DEVICE = torch.device(device)
+
+
+class Runner(object):
+    """docstring for Runner"""
+    def __init__(self, seed=0):
+        super(Runner, self).__init__()
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        if device == 'cuda':
+            torch.cuda.manual_seed(seed)
+
+    @staticmethod
+    def _forward(model, batch, poolingfunction):
+        inputs, targets = batch
+        inputs, targets = inputs.float().to(DEVICE), targets.float().to(DEVICE)
+        return poolingfunction(model(inputs), 1), targets
+
+    def train(self, config, **kwargs):
+        config_parameters = parse_config_or_kwargs(config, **kwargs)
+        outputdir = os.path.join(
+            config_parameters['outputpath'], config_parameters['model'],
+            "{}_{}".format(
+                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%m'),
+                uuid.uuid1().hex))
+        checkpoint_handler = ModelCheckpoint(
+            outputdir,
+            'run',
+            n_saved=1,
+            require_empty=False,
+            create_dir=True,
+            score_function=lambda engine: -engine.state.metrics['Loss'],
+            save_as_state_dict=False,
+            score_name='loss')
+
+        train_kaldi_string = parsecopyfeats(
+            config_parameters['trainfeatures'],
+            **config_parameters['feature_args'])
+        dev_kaldi_string = parsecopyfeats(config_parameters['devfeatures'],
+                                          **config_parameters['feature_args'])
+        logger = genlogger(os.path.join(outputdir, 'train.log'))
+        logger.info("Experiment is stored in {}".format(outputdir))
+        for line in pformat(config_parameters).split('\n'):
+            logger.info(line)
+        scaler = getattr(
+            pre,
+            config_parameters['scaler'])(**config_parameters['scaler_args'])
+        inputdim = -1
+        logger.info("<== Estimating Scaler ({}) ==>".format(
+            scaler.__class__.__name__))
+        for _, feat in kaldi_io.read_mat_ark(train_kaldi_string):
+            scaler.partial_fit(feat)
+            inputdim = feat.shape[-1]
+        assert inputdim > 0, "Reading inputstream failed"
+        logger.info("Features: {} Input dimension: {}".format(
+            config_parameters['trainfeatures'], inputdim))
+        logger.info("<== Labels ==>")
+        train_label_df = pd.read_csv(
+            config_parameters['trainlabels']).set_index('Participant_ID')
+        dev_label_df = pd.read_csv(
+            config_parameters['devlabels']).set_index('Participant_ID')
+        train_label_df.index = train_label_df.index.astype(str)
+        dev_label_df.index = dev_label_df.index.astype(str)
+        # target_type = ('PHQ8_Score', 'PHQ8_Binary')
+        target_type = ('PHQ8_Score', 'PHQ8_Binary')
+        n_labels = len(target_type)  # PHQ8 + Binary
+        # Scores and their respective PHQ8
+        train_labels = train_label_df.loc[:, target_type].T.apply(
+            tuple).to_dict()
+        dev_labels = dev_label_df.loc[:, target_type].T.apply(tuple).to_dict()
+        train_dataloader = create_dataloader(
+            train_kaldi_string,
+            train_labels,
+            transform=scaler.transform,
+            shuffle=True,
+            **config_parameters['dataloader_args'])
+        cv_dataloader = create_dataloader(
+            dev_kaldi_string,
+            dev_labels,
+            transform=scaler.transform,
+            shuffle=False,
+            **config_parameters['dataloader_args'])
+        model = getattr(models, config_parameters['model'])(
+            inputdim=inputdim,
+            output_size=n_labels,
+            **config_parameters['model_args'])
+        if 'pretrain' in config_parameters:
+            logger.info("Loading pretrained model {}".format(
+                config_parameters['pretrain']))
+            pretrained_model = torch.load(config_parameters['pretrain'],
+                                          map_location=lambda st, loc: st)
+            if 'Attn' in pretrained_model.__class__.__name__:
+                model.lstm.load_state_dict(pretrained_model.lstm.state_dict())
+            else:
+                model.net.load_state_dict(pretrained_model.net.state_dict())
+        logger.info("<== Model ==>")
+        for line in pformat(model).split('\n'):
+            logger.info(line)
+        criterion = getattr(
+            losses,
+            config_parameters['loss'])(**config_parameters['loss_args'])
+        optimizer = getattr(torch.optim, config_parameters['optimizer'])(
+            list(model.parameters()) + list(criterion.parameters()),
+            **config_parameters['optimizer_args'])
+        poolingfunction = parse_poolingfunction(
+            config_parameters['poolingfunction'])
+        criterion = criterion.to(device)
+        model = model.to(device)
 
-    def add(self, output, target):
-        if self.sigmoid_output:
-            output = torch.sigmoid(output)
-        output = output.float()
-        target = target.float()
-        output = output.round()
-        self.correct += int((output == target).sum())
-        self.n += np.prod(output.shape)
+        def _train_batch(_, batch):
+            model.train()
+            with torch.enable_grad():
+                optimizer.zero_grad()
+                outputs, targets = Runner._forward(model, batch,
+                                                   poolingfunction)
+                loss = criterion(outputs, targets)
+                loss.backward()
+                optimizer.step()
+                return loss.item()
+
+        def _inference(_, batch):
+            model.eval()
+            with torch.no_grad():
+                return Runner._forward(model, batch, poolingfunction)
+
+        def meter_transform(output):
+            y_pred, y = output
+            # y_pred is of shape [Bx2] (0 = MSE, 1 = BCE)
+            # y = is of shape [Bx2] (0=Mse, 1 = BCE)
+            return torch.sigmoid(y_pred[:, 1]).round(), y[:, 1].long()
+
+        precision = Precision(output_transform=meter_transform, average=False)
+        recall = Recall(output_transform=meter_transform, average=False)
+        F1 = (precision * recall * 2 / (precision + recall)).mean()
+        metrics = {
+            'Loss':
+            Loss(criterion),
+            'Recall':
+            Recall(output_transform=meter_transform, average=True),
+            'Precision':
+            Precision(output_transform=meter_transform, average=True),
+            'MAE':
+            MeanAbsoluteError(
+                output_transform=lambda out: (out[0][:, 0], out[1][:, 0])),
+            'F1':
+            F1
+        }
+
+        train_engine = Engine(_train_batch)
+        inference_engine = Engine(_inference)
+        for name, metric in metrics.items():
+            metric.attach(inference_engine, name)
+        RunningAverage(output_transform=lambda x: x).attach(
+            train_engine, 'run_loss')
+        pbar = ProgressBar(persist=False)
+        pbar.attach(train_engine, ['run_loss'])
+
+        scheduler = getattr(torch.optim.lr_scheduler,
+                            config_parameters['scheduler'])(
+                                optimizer,
+                                **config_parameters['scheduler_args'])
+        early_stop_handler = EarlyStopping(
+            patience=5,
+            score_function=lambda engine: -engine.state.metrics['Loss'],
+            trainer=train_engine)
+        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
+                                           early_stop_handler)
+        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
+                                           checkpoint_handler, {
+                                               'model': model,
+                                               'scaler': scaler,
+                                               'config': config_parameters
+                                           })
+
+        @train_engine.on(Events.EPOCH_COMPLETED)
+        def compute_metrics(engine):
+            inference_engine.run(cv_dataloader)
+            validation_string_list = [
+                "Validation Results - Epoch: {:<3}".format(engine.state.epoch)
+            ]
+            for metric in metrics:
+                validation_string_list.append("{}: {:<5.2f}".format(
+                    metric, inference_engine.state.metrics[metric]))
+            logger.info(" ".join(validation_string_list))
+
+            pbar.n = pbar.last_print_n = 0
+
+        @inference_engine.on(Events.COMPLETED)
+        def update_reduce_on_plateau(engine):
+            val_loss = engine.state.metrics['Loss']
+            if 'ReduceLROnPlateau' == scheduler.__class__.__name__:
+                scheduler.step(val_loss)
+            else:
+                scheduler.step()
+
+        train_engine.run(train_dataloader,
+                         max_epochs=config_parameters['epochs'])
+        # Return for further processing
+        return outputdir
+
+    def evaluate(self,
+                 experiment_path: str,
+                 outputfile: str = 'results.csv',
+                 **kwargs):
+        """Prints out the stats for the given model ( MAE, RMSE, F1, Pre, Rec)
+
+
+        """
+        config = torch.load(glob.glob(
+            "{}/run_config*".format(experiment_path))[0],
+                            map_location=lambda storage, loc: storage)
+        model = torch.load(glob.glob(
+            "{}/run_model*".format(experiment_path))[0],
+                           map_location=lambda storage, loc: storage)
+        scaler = torch.load(glob.glob(
+            "{}/run_scaler*".format(experiment_path))[0],
+                            map_location=lambda storage, loc: storage)
+        config_parameters = dict(config, **kwargs)
+        dev_features = config_parameters['devfeatures']
+        dev_label_df = pd.read_csv(
+            config_parameters['devlabels']).set_index('Participant_ID')
+        dev_label_df.index = dev_label_df.index.astype(str)
 
-    def value(self):
-        if self.n == 0:
-            return 0
-        return (self.correct / self.n) * 100.
+        dev_labels = dev_label_df.loc[:, ['PHQ8_Score', 'PHQ8_Binary'
+                                          ]].T.apply(tuple).to_dict()
+        outputfile = os.path.join(experiment_path, outputfile)
+        y_score_true, y_score_pred, y_binary_pred, y_binary_true = [], [], [], []
+
+        poolingfunction = parse_poolingfunction(
+            config_parameters['poolingfunction'])
+        dataloader = create_dataloader(dev_features,
+                                       dev_labels,
+                                       transform=scaler.transform,
+                                       batch_size=1,
+                                       num_workers=1,
+                                       shuffle=False)
+
+        model = model.to(device).eval()
+        with torch.no_grad():
+            for batch in dataloader:
+                output, target = Runner._forward(model, batch, poolingfunction)
+                y_score_pred.append(output[:, 0].cpu().numpy())
+                y_score_true.append(target[:, 0].cpu().numpy())
+                y_binary_pred.append(
+                    torch.sigmoid(output[:, 1]).round().cpu().numpy())
+                y_binary_true.append(target[:, 1].cpu().numpy())
+        y_score_true = np.concatenate(y_score_true)
+        y_score_pred = np.concatenate(y_score_pred)
+        y_binary_pred = np.concatenate(y_binary_pred)
+        y_binary_true = np.concatenate(y_binary_true)
+
+        with open(outputfile, 'w') as wp:
+            pre = metrics.precision_score(y_binary_true,
+                                          y_binary_pred,
+                                          average='macro')
+            rec = metrics.recall_score(y_binary_true,
+                                       y_binary_pred,
+                                       average='macro')
+            f1 = 2 * pre * rec / (pre + rec)
+            rmse = np.sqrt(
+                metrics.mean_squared_error(y_score_true, y_score_pred))
+            mae = metrics.mean_absolute_error(y_score_true, y_score_pred)
+            df = pd.DataFrame(
+                {
+                    'precision': pre,
+                    'recall': rec,
+                    'F1': f1,
+                    'MAE': mae,
+                    'RMSE': rmse
+                },
+                index=["Macro"])
+            df.to_csv(wp, index=False)
+            print(tabulate(df, headers='keys'))
+        return df
+
+    def evaluates(
+            self,
+            *experiment_paths: str,
+            outputfile: str = 'scores.csv',
+    ):
+        result_dfs = []
+        for exp_path in experiment_paths:
+            print("Evaluating {}".format(exp_path))
+            try:
+                result_df = self.evaluate(exp_path)
+                exp_config = torch.load(
+                    glob.glob("{}/run_config*".format(exp_path))[0],
+                    map_location=lambda storage, loc: storage)
+                result_df['exp'] = os.path.basename(exp_path)
+                result_df['model'] = exp_config['model']
+                result_df['optimizer'] = exp_config['optimizer']
+                result_df['batch_size'] = exp_config['dataloader_args'][
+                    'batch_size']
+                result_df['poolingfunction'] = exp_config['poolingfunction']
+                result_df['loss'] = exp_config['loss']
+                result_dfs.append(result_df)
+            except Exception as e:  #Sometimes EOFError happens
+                pass
+        df = pd.concat(result_dfs)
+        df.sort_values(by='F1', ascending=False, inplace=True)
+
+        with open(outputfile, 'w') as wp:
+            df.to_csv(wp, index=False)
+            print(tabulate(df, headers='keys', tablefmt="pipe"))
 
 
 def parsecopyfeats(feat, cmvn=False, delta=False, splice=None):
-    outstr = "copy-feats ark:{} ark:- |".format(feat)
-    if cmvn:
-        outstr += "apply-cmvn-sliding --center ark:- ark:- |"
-    if delta:
-        outstr += "add-deltas ark:- ark:- |"
-    if splice and splice > 0:
-        outstr += "splice-feats --left-context={} --right-context={} ark:- ark:- |".format(
-            splice, splice)
+    # Check if user has kaldi installed, otherwise just use kaldi_io (without extra transformations)
+    import shutil
+    if shutil.which('copy-feats') is None:
+        return feat
+    else:
+        outstr = "copy-feats ark:{} ark:- |".format(feat)
+        if cmvn:
+            outstr += "apply-cmvn-sliding --center ark:- ark:- |"
+        if delta:
+            outstr += "add-deltas ark:- ark:- |"
+        if splice and splice > 0:
+            outstr += "splice-feats --left-context={} --right-context={} ark:- ark:- |".format(
+                splice, splice)
     return outstr
 
 
-def runepoch(dataloader, model, criterion, optimizer=None, dotrain=True, poolfun=lambda x, d: x.mean(d)):
-    model = model.train() if dotrain else model.eval()
-    # By default use average pooling
-    loss_meter = tnt.meter.AverageValueMeter()
-    acc_meter = BinaryAccuracyMeter(sigmoid_output=True)
-    with torch.set_grad_enabled(dotrain):
-        for i, (features, targets) in enumerate(dataloader):
-            features = features.float().to(device)
-            targets = targets.to(device)
-            outputs = model(features)
-            outputs = poolfun(outputs, 1)
-            loss = criterion(outputs, targets).cpu()
-            loss_meter.add(loss.item())
-            acc_meter.add(outputs.cpu().data[:, 1], targets.cpu().data[:, 1])
-            if dotrain:
-                optimizer.zero_grad()
-                loss.backward()
-                optimizer.step()
-
-    return loss_meter.value(), acc_meter.value()
-
-
def genlogger(outputfile):
    """Create a logger that writes to both stdout and ``outputfile``.

    The logger is named after the output file, so distinct runs get distinct
    loggers. Handlers are attached only once per logger: calling this twice
    with the same file no longer duplicates every log line.

    :param outputfile: Path of the log file to append to.
    :returns: Configured ``logging.Logger`` instance (level INFO).
    """
    formatter = logging.Formatter(
        "[ %(levelname)s : %(asctime)s ] - %(message)s")
    logger = logging.getLogger(__name__ + "." + outputfile)
    logger.setLevel(logging.INFO)
    # Guard: getLogger returns the same object per name, so re-adding
    # handlers on a second call would double every message.
    if not logger.handlers:
        stdlog = logging.StreamHandler(sys.stdout)
        stdlog.setFormatter(formatter)
        file_handler = logging.FileHandler(outputfile)
        file_handler.setFormatter(formatter)
        # Log to both the file and stdout
        logger.addHandler(file_handler)
        logger.addHandler(stdlog)
    return logger
 
 
 def parse_config_or_kwargs(config_file, **kwargs):
     with open(config_file) as con_read:
-        yaml_config = yaml.load(con_read)
+        yaml_config = yaml.load(con_read, Loader=yaml.FullLoader)
     # passed kwargs will override yaml config
-    for key in kwargs.keys():
-        assert key in yaml_config, "Parameter {} invalid!".format(key)
+    # for key in kwargs.keys():
+    # assert key in yaml_config, "Parameter {} invalid!".format(key)
     return dict(yaml_config, **kwargs)
 
 
-def criterion_improver(mode):
-    """Returns a function to ascertain if criterion did improve
-
-    :mode: can be ether 'loss' or 'acc'
-    :returns: function that can be called, function returns true if criterion improved
-
-    """
-    assert mode in ('loss', 'acc')
-    best_value = np.inf if mode == 'loss' else 0
-
-    def comparator(x, best_x):
-        return x < best_x if mode == 'loss' else x > best_x
-
-    def inner(x):
-        # rebind parent scope variable
-        nonlocal best_value
-        if comparator(x, best_value):
-            best_value = x
-            return True
-        return False
-    return inner
-
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-torch.manual_seed(0)
-np.random.seed(0)
-
-if device == 'cuda':
-    torch.cuda.manual_seed_all(0)
-
-
-def train(config='config/text_lstm_deep.yaml', **kwargs):
-    """Trains a model on the given features and vocab.
-
-    :features: str: Input features. Needs to be kaldi formatted file
-    :config: A training configuration. Note that all parameters in the config can also be manually adjusted with --ARG=VALUE
-    :returns: None
-    """
-
-    config_parameters = parse_config_or_kwargs(config, **kwargs)
-    outputdir = os.path.join(
-        config_parameters['outputpath'],
-        config_parameters['model'],
-        datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%f'))
-    try:
-        os.makedirs(outputdir)
-    except IOError:
-        pass
-    logger = genlogger(outputdir, 'train.log')
-    logger.info("Storing data at: {}".format(outputdir))
-    logger.info("<== Passed Arguments ==>")
-    # Print arguments into logs
-    for line in pformat(config_parameters).split('\n'):
-        logger.info(line)
-
-    train_kaldi_string = parsecopyfeats(
-        config_parameters['trainfeatures'], **config_parameters['feature_args'])
-    dev_kaldi_string = parsecopyfeats(
-        config_parameters['devfeatures'], **config_parameters['feature_args'])
-
-    scaler = getattr(
-        pre, config_parameters['scaler'])(
-        **config_parameters['scaler_args'])
-    inputdim = -1
-    logger.info(
-        "<== Estimating Scaler ({}) ==>".format(
-            scaler.__class__.__name__))
-    for kid, feat in kaldi_io.read_mat_ark(train_kaldi_string):
-        scaler.partial_fit(feat)
-        inputdim = feat.shape[-1]
-    assert inputdim > 0, "Reading inputstream failed"
-    logger.info(
-        "Features: {} Input dimension: {}".format(
-            config_parameters['trainfeatures'],
-            inputdim))
-    logger.info("<== Labels ==>")
-    train_label_df = pd.read_csv(
-        config_parameters['trainlabels']).set_index('Participant_ID')
-    dev_label_df = pd.read_csv(
-        config_parameters['devlabels']).set_index('Participant_ID')
-    train_label_df.index = train_label_df.index.astype(str)
-    dev_label_df.index = dev_label_df.index.astype(str)
-
-    target_type = ('PHQ8_Score', 'PHQ8_Binary')
-
-    # Scores and their respective
-    train_labels = train_label_df.loc[:, target_type].T.apply(tuple).to_dict()
-    dev_labels = dev_label_df.loc[:, target_type].T.apply(tuple).to_dict()
-    n_labels = len(target_type)
-
-    train_dataloader = create_dataloader(
-        train_kaldi_string,
-        train_labels,
-        transform=scaler.transform,
-        **config_parameters['dataloader_args'])
-    cv_dataloader = create_dataloader(
-        dev_kaldi_string,
-        dev_labels,
-        transform=scaler.transform,
-        **config_parameters['dataloader_args'])
-    model = getattr(
-        models,
-        config_parameters['model'])(
-        inputdim=inputdim,
-        output_size=n_labels,
-        **config_parameters['model_args'])
-    logger.info("<== Model ==>")
-    for line in pformat(model).split('\n'):
-        logger.info(line)
-    model = model.to(device)
-    optimizer = getattr(
-        torch.optim, config_parameters['optimizer'])(
-        model.parameters(),
-        **config_parameters['optimizer_args'])
-
-    scheduler = getattr(
-        torch.optim.lr_scheduler,
-        config_parameters['scheduler'])(
-        optimizer,
-        **config_parameters['scheduler_args'])
-    criterion = getattr(losses, config_parameters['loss'])(
-        **config_parameters['loss_args'])
-    criterion.to(device)
-
-    trainedmodelpath = os.path.join(outputdir, 'model.th')
-
-    criterion_improved = criterion_improver(
-        config_parameters['improvecriterion'])
-    header = [
-        'Epoch',
-        'Loss(T)',
-        'Loss(CV)',
-        "Acc(T)",
-        "Acc(CV)",
-    ]
-    for line in tp.header(
-        header,
-            style='grid').split('\n'):
-        logger.info(line)
-
-    poolingfunction_name = config_parameters['poolingfunction']
-    pooling_function = parse_poolingfunction(poolingfunction_name)
-    for epoch in range(1, config_parameters['epochs']+1):
-        train_utt_loss_mean_std, train_utt_acc = runepoch(
-            train_dataloader, model, criterion, optimizer, dotrain=True, poolfun=pooling_function)
-        cv_utt_loss_mean_std, cv_utt_acc = runepoch(
-            cv_dataloader, model,  criterion, dotrain=False, poolfun=pooling_function)
-        logger.info(
-            tp.row(
-                (epoch,) +
-                (train_utt_loss_mean_std[0],
-                 cv_utt_loss_mean_std[0],
-                 train_utt_acc, cv_utt_acc),
-                style='grid'))
-        epoch_meanloss = cv_utt_loss_mean_std[0]
-        if epoch % config_parameters['saveinterval'] == 0:
-            torch.save({'model': model,
-                        'scaler': scaler,
-                        # 'encoder': many_hot_encoder,
-                        'config': config_parameters},
-                       os.path.join(outputdir, 'model_{}.th'.format(epoch)))
-        # ReduceOnPlateau needs a value to work
-        schedarg = epoch_meanloss if scheduler.__class__.__name__ == 'ReduceLROnPlateau' else None
-        scheduler.step(schedarg)
-        if criterion_improved(epoch_meanloss):
-            torch.save({'model': model,
-                        'scaler': scaler,
-                        # 'encoder': many_hot_encoder,
-                        'config': config_parameters},
-                       trainedmodelpath)
-        if optimizer.param_groups[0]['lr'] < 1e-7:
-            break
-    logger.info(tp.bottom(len(header), style='grid'))
-    logger.info("Results are in: {}".format(outputdir))
-    return outputdir
-
-
 def parse_poolingfunction(poolingfunction_name='mean'):
     if poolingfunction_name == 'mean':
-        def pooling_function(x, d): return x.mean(d)
+
+        def pooling_function(x, d):
+            return x.mean(d)
     elif poolingfunction_name == 'max':
-        def pooling_function(x, d): return x.max(d)[0]
+
+        def pooling_function(x, d):
+            return x.max(d)[0]
     elif poolingfunction_name == 'linear':
-        def pooling_function(x, d): return (x**2).sum(d) / x.sum(d)
+
+        def pooling_function(x, d):
+            return (x**2).sum(d) / x.sum(d)
     elif poolingfunction_name == 'exp':
-        def pooling_function(x, d): return (
-            x.exp() * x).sum(d) / x.exp().sum(d)
-    elif poolingfunction_name == 'time':  # Last timestep
-        def pooling_function(x, d): return x.select(d, -1)
-    elif poolingfunction_name == 'first':
-        def pooling_function(x, d): return x.select(d, 0)
 
-    return pooling_function
+        def pooling_function(x, d):
+            return (x.exp() * x).sum(d) / x.exp().sum(d)
+    elif poolingfunction_name == 'last':  # Last timestep
 
+        def pooling_function(x, d):
+            return x.select(d, -1)
+    elif poolingfunction_name == 'first':
 
-def _extract_features_from_model(model, features, scaler=None):
-    if model.__class__.__name__ == 'LSTM':
-        fwdmodel = torch.nn.Sequential(model.net)
-    elif model.__class__.__name__ == 'LSTMSimpleAttn':
-        fwdmodel = torch.nn.Sequential(model)
-    elif model.__class__.__name__ == 'TCN':
-        fwdmodel = None
+        def pooling_function(x, d):
+            return x.select(d, 0)
     else:
-        assert False, "Model not prepared for extraction"
-    ret = {}
-    with torch.no_grad():
-        model = model.to(device)
-        for k, v in kaldi_io.read_mat_ark(features):
-            if scaler:
-                v = scaler.transform(v)
-            v = torch.from_numpy(v).to(device).unsqueeze(0)
-            out = fwdmodel(v)
-            if isinstance(out, tuple):  # LSTM output, 2 values hidden,and x
-                out = out[0]
-            ret[k] = out.cpu().squeeze().numpy()
-    return ret
-
-
-def extract_features(model_path: str, features='trainfeatures'):
-    modeldump = torch.load(model_path, lambda storage, loc: storage)
-    model_dir = os.path.dirname(model_path)
-    config_parameters = modeldump['config']
-    dev_features = config_parameters[features]
-    scaler = modeldump['scaler']
-    model = modeldump['model']
-
-    outputfile = os.path.join(model_dir, features + '.ark')
-    dev_features = parsecopyfeats(
-        dev_features, **config_parameters['feature_args'])
-
-    vectors = _extract_features_from_model(model, dev_features, scaler)
-    with open(outputfile, 'wb') as wp:
-        for key, vector in vectors.items():
-            kaldi_io.write_mat(wp, vector, key=key)
-    return outputfile
-
-
-def stats(model_path: str, outputfile: str = 'stats.txt', cutoff: int = None):
-    """Prints out the stats for the given model ( MAE, RMSE, F1, Pre, Rec)
-
-    :model_path:str: TODO
-    :returns: TODO
-
-    """
-    from tabulate import tabulate
-    modeldump = torch.load(model_path, lambda storage, loc: storage)
-    model_dir = os.path.dirname(model_path)
-    config_parameters = modeldump['config']
-    dev_features = config_parameters['devfeatures']
-    dev_label_df = pd.read_csv(
-        config_parameters['devlabels']).set_index('Participant_ID')
-    dev_label_df.index = dev_label_df.index.astype(str)
-
-    dev_labels = dev_label_df.loc[:, [
-        'PHQ8_Score', 'PHQ8_Binary']].T.apply(tuple).to_dict()
-    outputfile = os.path.join(model_dir, outputfile)
-    y_score_true, y_score_pred, y_binary_pred, y_binary_true = [], [], [], []
-    scores = _forward_model(model_path, dev_features, cutoff=cutoff)
-    for key, score in scores.items():
-        score_pred, binary_pred = torch.chunk(score, 2, dim=-1)
-        y_score_pred.append(score_pred.numpy())
-        y_score_true.append(dev_labels[key][0])
-        y_binary_pred.append(torch.sigmoid(
-            binary_pred).round().numpy().astype(int).item())
-        y_binary_true.append(dev_labels[key][1])
-
-    with open(outputfile, 'w') as wp:
-        pre = metrics.precision_score(
-            y_binary_true, y_binary_pred, average='macro')
-        rec = metrics.recall_score(
-            y_binary_true, y_binary_pred, average='macro')
-        f1 = 2*pre*rec / (pre+rec)
-        rmse = np.sqrt(metrics.mean_squared_error(y_score_true, y_score_pred))
-        mae = metrics.mean_absolute_error(y_score_true, y_score_pred)
-        df = pd.DataFrame(
-            {'precision': pre, 'recall': rec, 'F1': f1, 'MAE': mae, 'RMSE': rmse}, index=["Macro"])
-        print(tabulate(df, headers='keys'), file=wp)
-        print(tabulate(df, headers='keys'))
-
-
-def fuse(model_paths: list, outputfile='scores.txt', cutoff: int = None):
-    from tabulate import tabulate
-    scores = []
-    for model_path in model_paths:
-        modeldump = torch.load(model_path, lambda storage, loc: storage)
-        config_parameters = modeldump['config']
-        dev_features = config_parameters['devfeatures']
-        dev_label_df = pd.read_csv(
-            config_parameters['devlabels']).set_index('Participant_ID')
-        dev_label_df.index = dev_label_df.index.astype(str)
-        score = _forward_model(model_path, dev_features, cutoff=cutoff)
-        for speaker, pred_score in score.items():
-            scores.append({
-                'speaker': speaker,
-                'MAE': float(pred_score[0].numpy()),
-                'binary': float(torch.sigmoid(pred_score[1]).numpy()),
-                'model': model_path,
-                'binary_true': dev_label_df.loc[speaker, 'PHQ8_Binary'],
-                'MAE_true': dev_label_df.loc[speaker, 'PHQ8_Score']
-            })
-    df = pd.DataFrame(scores)
-
-    spkmeans = df.groupby('speaker')[['MAE', 'MAE_true', 'binary', 'binary_true']].mean()
-    spkmeans['binary'] = spkmeans['binary'] > 0.5
-
-    with open(outputfile, 'w') as wp:
-        pre = metrics.precision_score(
-            spkmeans['binary_true'].values, spkmeans['binary'].values, average='macro')
-        rec = metrics.recall_score(
-            spkmeans['binary_true'].values, spkmeans['binary'].values, average='macro')
-        f1 = 2*pre*rec / (pre+rec)
-        rmse = np.sqrt(metrics.mean_squared_error(
-            spkmeans['MAE_true'].values, spkmeans['MAE'].values))
-        mae = metrics.mean_absolute_error(
-            spkmeans['MAE_true'].values, spkmeans['MAE'].values)
-        df = pd.DataFrame(
-            {'precision': pre, 'recall': rec, 'F1': f1, 'MAE': mae, 'RMSE': rmse}, index=["Macro"])
-        print(tabulate(df, headers='keys'), file=wp)
-        print(tabulate(df, headers='keys'))
-
-
-def _forward_model(model_path: str, features: str, dopooling: bool = True, cutoff=None):
-    modeldump = torch.load(model_path, lambda storage, loc: storage)
-    scaler = modeldump['scaler']
-    config_parameters = modeldump['config']
-    pooling_function = parse_poolingfunction(
-        config_parameters['poolingfunction'])
-    kaldi_string = parsecopyfeats(
-        features, **config_parameters['feature_args'])
-    ret = {}
-
-    with torch.no_grad():
-        model = modeldump['model'].to(device).eval()
-        for key, feat in kaldi_io.read_mat_ark(kaldi_string):
-            feat = scaler.transform(feat)
-            if cutoff:
-                # Cut all after cutoff
-                feat = feat[:cutoff]
-            feat = torch.from_numpy(feat).to(device).unsqueeze(0)
-            output = model(feat).cpu()
-            if dopooling:
-                output = pooling_function(output, 1).squeeze(0)
-            ret[key] = output
-    return ret
-
-
-def trainstats(config: str = 'config/text_lstm_deep.yaml', **kwargs):
-    """Runs training and then prints dev stats
-
-    :config:str: config file
-    :**kwargs: Extra overwrite configs
-    :returns: None
-
-    """
-    output_model = train(config, **kwargs)
-    best_model = os.path.join(output_model, 'model.th')
-    stats(best_model)
-
-
-def run_search(config: str = 'config/text_lstm_deep.yaml', lr=0.1, mom=0.9, nest=False, **kwargs):
-    """Runs training and then prints dev stats
-
-    :config:str: config file
-    :**kwargs: Extra overwrite configs
-    :returns: None
-
-    """
-    optimizer_args = {'lr': lr, 'momentum': mom, 'nesterov': nest}
-    kwargs['optimizer_args'] = optimizer_args
-    output_model = train(config, **kwargs)
-    best_model = os.path.join(output_model, 'model.th')
-    stats(best_model)
-
-
-def run_search_adam(config: str = 'config/text_lstm_deep.yaml', lr=0.1, **kwargs):
-    """Runs training and then prints dev stats
-
-    :config:str: config file
-    :**kwargs: Extra overwrite configs
-    :returns: None
-
-    """
-    optimizer_args = {'lr': lr}
-    kwargs['optimizer_args'] = optimizer_args
-    output_model = train(config, **kwargs)
-    best_model = os.path.join(output_model, 'model.th')
-    stats(best_model)
+        raise ValueError(
+            "Pooling function {} not available".format(poolingfunction_name))
+
+    return pooling_function
 
 
 if __name__ == '__main__':
-    fire.Fire({
-        'train': train,
-        'stats': stats,
-        'trainstats': trainstats,
-        'search': run_search,
-        'searchadam': run_search_adam,
-        'ex': extract_features,
-        'fwd': _forward_model,
-        'fuse': fuse,
-    })
+    fire.Fire(Runner)
diff --git a/show_most_common_words.py b/show_most_common_words.py
index 4506ee0..b228231 100644
--- a/show_most_common_words.py
+++ b/show_most_common_words.py
@@ -7,8 +7,7 @@
 from tabulate import tabulate
 from scipy.signal import find_peaks
 import kaldi_io
-from glob import glob
-
+import glob
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -17,16 +16,17 @@ def read_transcripts(root: str, speakers: list, filterlen: int):
     transcripts = {}
     for speaker in speakers:
         # PRocess transcript first to get start_end
-        transcript_file = glob(os.path.join(
-            root, str(speaker)) + '*TRANSCRIPT.csv')[0]
+        transcript_file = glob.glob(
+            os.path.join(root, str(speaker)) + '*TRANSCRIPT.csv')[0]
         transcript_df = pd.read_csv(transcript_file, sep='\t')
         transcript_df.dropna(inplace=True)
         transcript_df.value = transcript_df.value.str.strip()
         # Subset for
-        transcript_df = transcript_df[transcript_df.value.str.split().apply(
-            len) > filterlen]
-        transcript_df = transcript_df[transcript_df.speaker ==
-                                      'Participant'].reset_index().loc[:, 'value'].to_dict()
+        transcript_df = transcript_df[
+            transcript_df.value.str.split().apply(len) > filterlen]
+        transcript_df = transcript_df[
+            transcript_df.speaker ==
+            'Participant'].reset_index().loc[:, 'value'].to_dict()
         transcripts[speaker] = transcript_df
     return transcripts
 
@@ -35,58 +35,68 @@ def forward_model(model, x, pooling):
     model = model.eval()
     model.to(device)
     if 'LSTMAttn' in model.__class__.__name__:
-        lx = torch.sigmoid(model.local_embed(x))
         x, _ = model.lstm.net(x)
-        x = lx*x
         x, attention_weights = model.attn(x)
         return model.lstm.outputlayer(x), attention_weights
     elif 'GRUAttn' in model.__class__.__name__:
         x, _ = model.net(x)
         x, attention_weights = model.attn(x)
         return model.outputlayer(x), attention_weights
-    elif 'LSTMSimpleAttn' in model.__class__.__name__:
-        x, _ = model.lstm.net(x)
-        x, attention_weights = model.attn(x)
-        return model.lstm.outputlayer(x), attention_weights
     else:
         outputs = model(x)
         return pooling(outputs, 1), outputs[:, :, 0]
 
 
-def show_most_common(model_path: str, max_heigth: float = 0.8, filterlen: int = 0):
-    modeldump = torch.load(model_path, lambda storage, loc: storage)
-    scaler = modeldump['scaler']
-    config_parameters = modeldump['config']
-    data = config_parameters['devfeatures']
-    dev_labels = config_parameters['devlabels']
-    model = modeldump['model']
+def show_most_common(experiment_path: str,
+                     max_heigth: float = 0.8,
+                     filterlen: int = 0,
+                     shift:int = 0):
+
+    config = torch.load(glob.glob("{}/run_config*".format(experiment_path))[0],
+                        map_location=lambda storage, loc: storage)
+    model = torch.load(glob.glob("{}/run_model*".format(experiment_path))[0],
+                       map_location=lambda storage, loc: storage)
+    scaler = torch.load(glob.glob("{}/run_scaler*".format(experiment_path))[0],
+                        map_location=lambda storage, loc: storage)
+    data = config['devfeatures']
+    dev_labels = config['devlabels']
     dev_labels_binary = pd.read_csv(dev_labels)
-    poolingfunction = parse_poolingfunction(
-        config_parameters['poolingfunction'])
-    TRANSCRIPT_ROOT = '/mnt/lustre/sjtu/users/hedi7/depression/woz/system/data_preprocess/labels_processed/'
-    transcripts = read_transcripts(
-        TRANSCRIPT_ROOT, dev_labels_binary['Participant_ID'].values, filterlen)
+    poolingfunction = parse_poolingfunction(config['poolingfunction'])
+    TRANSCRIPT_ROOT = '../data_preprocess/labels_processed/'
+    transcripts = read_transcripts(TRANSCRIPT_ROOT,
+                                   dev_labels_binary['Participant_ID'].values,
+                                   filterlen)
     #
     all_words = []
     with torch.no_grad():
         for k, v in kaldi_io.read_mat_ark(data):
             v = scaler.transform(v)
             cur_transcripts = transcripts[int(k)]
+            v = np.roll(v, shift, axis=0) # Shifting by n sentences
             v = torch.from_numpy(v).unsqueeze(0).float().to(device)
             output, weights = forward_model(model, v, poolingfunction)
-            output, weights = output.squeeze().cpu(), weights.squeeze().cpu().numpy()
+            output, weights = output.squeeze().cpu(), weights.squeeze().cpu(
+            ).numpy()
             model_thinks_depression = float(torch.sigmoid(output[1]).numpy())
-            peaks = find_peaks(weights, height=np.max(weights)*max_heigth)[0]
-            assert len(cur_transcripts) == len(weights), "Trans: {} Weight: {}".format(
-                len(cur_transcripts), len(weights))
+            peaks = find_peaks(weights, height=np.max(weights) * max_heigth)[0]
+            assert len(cur_transcripts) == len(
+                weights), "Trans: {} Weight: {}".format(
+                    len(cur_transcripts), len(weights))
             for peak in peaks:
                 all_words.append({
-                    'sent': cur_transcripts[peak],
-                    'label': bool(dev_labels_binary[dev_labels_binary['Participant_ID'] == int(k)].PHQ8_Binary.values[0]),
-                    'predprob': model_thinks_depression,
-                    'pred': model_thinks_depression > 0.5
+                    'sent':
+                    cur_transcripts[peak],
+                    'label':
+                    bool(dev_labels_binary[dev_labels_binary['Participant_ID']
+                                           == int(k)].PHQ8_Binary.values[0]),
+                    'predprob':
+                    model_thinks_depression,
+                    'pred':
+                    model_thinks_depression > 0.5
                 })
 
+    if not all_words:
+        raise ValueError("Nothing found for the specified peak limit, maybe lower --max_height ?")
     df = pd.DataFrame(all_words)
     # aggregate = df.groupby('sent', as_index=False).agg(
     # ['count', 'mean'])