Skip to content

Commit

Permalink
Optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
anotherbugmaster committed Jun 16, 2018
1 parent 8e647a1 commit 89cc803
Showing 1 changed file with 36 additions and 30 deletions.
66 changes: 36 additions & 30 deletions gensim/models/nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

logger = logging.getLogger(__name__)

import time


class Nmf(interfaces.TransformationABC, basemodel.BaseTopicModel):
"""Online Non-Negative Matrix Factorization.
Expand All @@ -25,6 +23,7 @@ def __init__(
passes=1,
lambda_=1.,
kappa=1.,
use_r=False,
store_r=False,
w_max_iter=200,
w_stop_condition=1e-4,
Expand Down Expand Up @@ -56,13 +55,16 @@ def __init__(
Whether to save residuals during training
normalize
"""
self._w_error = None
self._h_r_error = None
self.n_features = None
self.num_topics = num_topics
self.id2word = id2word
self.chunksize = chunksize
self.passes = passes
self._lambda_ = lambda_
self._kappa = kappa
self.use_r = use_r
self._w_max_iter = w_max_iter
self._w_stop_condition = w_stop_condition
self._h_r_max_iter = h_r_max_iter
Expand Down Expand Up @@ -96,13 +98,13 @@ def B(self, value):
self._B = value

def get_topics(self):
# if self.normalize:
# return self.__softmax(self._W.T, axis=1)
if self.normalize:
return self._W.T / self._W.T.sum(axis=1).reshape(-1, 1)

return self._W.T

def __getitem__(self, bow, eps=None):
return self.get_document_topics(bow, eps)
return self.get_document_topics(bow, eps)[0]

def show_topics(self, num_topics=10, num_words=10, log=False, formatted=True):
"""
Expand Down Expand Up @@ -178,11 +180,6 @@ def get_topic_terms(self, topicid, topn=10):
bestn = matutils.argsort(topic, topn, reverse=True)
return [(idx, topic[idx]) for idx in bestn]

@staticmethod
def __softmax(matrix, axis):
exp_matrix = np.exp(matrix - matrix.max(axis=axis))
return exp_matrix / exp_matrix.sum(axis=axis)

def get_term_topics(self, word_id, minimum_probability=None):
"""
Args:
Expand Down Expand Up @@ -212,15 +209,15 @@ def get_term_topics(self, word_id, minimum_probability=None):

def get_document_topics(self, bow, minimum_probability=None):
v = matutils.corpus2dense([bow], len(self.id2word), 1)
h, _ = self._solveproj(v, self._W, v_max=np.inf)
h, r = self._solveproj(v, self._W, v_max=np.inf)

if self.normalize:
h = self.__softmax(h, axis=0)
h /= h.sum(axis=0)

if minimum_probability is not None:
h[h < minimum_probability] = 0

return h
return h, r

def _setup(self, corpus):
self._h, self._r = None, None
Expand Down Expand Up @@ -282,28 +279,36 @@ def update(self, corpus, chunks_as_numpy=False):

chunk_idx += 1

logger.info(
"Loss (no outliers): {}\tLoss (with outliers): {}".format(
np.linalg.norm(v - self._W.dot(h)),
np.linalg.norm(v - self._W.dot(h) - r),
)
)

self._r = r
self._h = h

def _solve_w(self):
eta = self._kappa / np.linalg.norm(self.A, "fro")
error = None

if not self._w_error:
self._w_error = self.__w_error()

for iter_number in range(self._w_max_iter):
# logger.info("w_error: %s" % self._w_error)

self._W -= eta * (np.dot(self._W, self.A) - self.B)
self.__transform()

if iter_number == self._w_max_iter - 1:
break

error_ = self.__w_error()

if error and np.abs(error_ - error) < np.abs(
error * self._w_stop_condition
if np.abs(error_ - self._w_error) < np.abs(
self._w_error * self._w_stop_condition
):
break

error = error_
self._w_error = error_

def __w_error(self):
return 0.5 * np.trace(self._W.T.dot(self._W.dot(self.A) - self.B))
Expand Down Expand Up @@ -348,27 +353,28 @@ def _solveproj(self, v, W, h=None, r=None, v_max=None):

# eta = self._kappa / np.linalg.norm(W, 'fro') ** 2

error = None
if not self._h_r_error:
self._h_r_error = self.__h_r_error(v, h, r)

for iter_number in range(self._h_r_max_iter):
# logger.info("h_r_error: %s" % self._h_r_error)

Wt_v_minus_r = W.T.dot(v - r)

solve_h(h, Wt_v_minus_r, WtW, self._kappa)

r_actual = v - W.dot(h)

solve_r(r, r_actual, self._lambda_, self.v_max)

if iter_number == self._h_r_max_iter - 1:
break
if self.use_r:
r_actual = v - W.dot(h)
solve_r(r, r_actual, self._lambda_, self.v_max)

error_ = self.__h_r_error(v, h, r)

if error and np.abs(error - error_) < np.abs(
error * self._h_r_stop_condition
if np.abs(self._h_r_error - error_) < np.abs(
self._h_r_error * self._h_r_stop_condition
):
break

error = error_
self._h_r_error = error_


return h, r

0 comments on commit 89cc803

Please sign in to comment.