Cancel GMMHMM training on floating point errors
HelmerNylen committed Dec 10, 2020
1 parent 3bd9ac9 commit 2bdca44
Showing 2 changed files with 66 additions and 0 deletions.
1 change: 1 addition & 0 deletions classifier/confusion_table.py
@@ -1,5 +1,6 @@
import numpy as np

# Might be good to replace this with the sklearn confusion table
class ConfusionTable:
"""Confusion table keeping track of testing metrics"""
TOTAL = ...
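The new comment suggests that the hand-rolled ConfusionTable could eventually be swapped for scikit-learn's confusion matrix utilities. A rough, illustrative sketch of what that replacement might look like (the noise-class label names below are made up for the example):

from sklearn.metrics import confusion_matrix

# Hypothetical true vs. predicted noise-class labels for a handful of samples.
y_true = ["hum", "hum", "click", "silence", "click"]
y_pred = ["hum", "click", "click", "silence", "silence"]

labels = ["hum", "click", "silence"]
table = confusion_matrix(y_true, y_pred, labels=labels)
print(table)  # rows = true labels, columns = predicted labels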
65 changes: 65 additions & 0 deletions classifier/model_gmmhmm.py
@@ -4,6 +4,9 @@
from hmmlearn import hmm
from .model import Model

from sklearn.utils import check_array
from hmmlearn.utils import iter_from_X_lengths

import logging
_log = logging.getLogger(hmm.__name__)

@@ -201,3 +204,65 @@ def _do_mstep(self, stats):

        else:
            super()._do_mstep(stats)

    # pylint: disable=redefined-builtin
    def fit(self, X, lengths=None):
        """Estimate model parameters.

        An initialization step is performed before entering the
        EM algorithm. If you want to avoid this step for a subset of
        the parameters, pass a proper ``init_params`` keyword argument
        to the estimator's constructor.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Feature matrix of individual samples.
        lengths : array-like of integers, shape (n_sequences, )
            Lengths of the individual sequences in ``X``. The sum of
            these should be ``n_samples``.

        Returns
        -------
        self : object
            Returns self.
        """
        X = check_array(X)
        self._init(X, lengths=lengths)
        self._check()

        self.monitor_._reset()
        for iter in range(self.n_iter):
            stats = self._initialize_sufficient_statistics()
            curr_logprob = 0
            for i, j in iter_from_X_lengths(X, lengths):
                framelogprob = self._compute_log_likelihood(X[i:j])
                logprob, fwdlattice = self._do_forward_pass(framelogprob)
                curr_logprob += logprob
                bwdlattice = self._do_backward_pass(framelogprob)
                posteriors = self._compute_posteriors(fwdlattice, bwdlattice)
                try:
                    # Promote invalid floating point operations to exceptions
                    # so that diverging training can be aborted cleanly.
                    with np.errstate(invalid="raise"):
                        self._accumulate_sufficient_statistics(
                            stats, X[i:j], framelogprob, posteriors, fwdlattice,
                            bwdlattice)
                except FloatingPointError as e:
                    print(f"{type(e).__name__}: {e}")
                    print("Divergence detected, stopping training")
                    return self

            # XXX must be before convergence check, because otherwise
            # there won't be any updates for the case ``n_iter=1``.
            self._do_mstep(stats)

            self.monitor_.report(curr_logprob)
            if self.monitor_.converged:
                break

        if (self.transmat_.sum(axis=1) == 0).any():
            _log.warning("Some rows of transmat_ have zero sum because no "
                         "transition from the state was ever observed.")

        return self
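The early abort relies on NumPy's error-state handling: inside ``np.errstate(invalid="raise")``, an invalid floating point operation raises ``FloatingPointError`` instead of silently producing NaN. A minimal standalone sketch of that mechanism (not part of the commit):

import numpy as np

# Outside a raising error state, an invalid operation just yields NaN.
with np.errstate(invalid="ignore"):
    print(np.log(-1.0))  # nan

# With invalid="raise", the same operation raises FloatingPointError,
# which is what the except branch in fit() catches to stop training.
try:
    with np.errstate(invalid="raise"):
        np.log(-1.0)
except FloatingPointError as e:
    print(f"{type(e).__name__}: {e}")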

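For context, ``fit`` follows the usual hmmlearn convention described in the docstring: all training sequences are stacked into one feature matrix and their lengths are passed separately. A hedged usage sketch, using plain ``hmm.GMMHMM`` to stand in for the patched class (whose name is not visible in this diff):

import numpy as np
from hmmlearn import hmm

# Two training sequences of 100 and 80 frames with 13 features each,
# concatenated row-wise; `lengths` tells fit() where each sequence ends.
seq1 = np.random.randn(100, 13)
seq2 = np.random.randn(80, 13)
X = np.concatenate([seq1, seq2])
lengths = [len(seq1), len(seq2)]

model = hmm.GMMHMM(n_components=3, n_mix=2, n_iter=20)
model.fit(X, lengths)  # the patched subclass would instead return early on divergence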