diff --git a/gensim/models/nmf.py b/gensim/models/nmf.py index 601fd093c2..774ba4dfe4 100644 --- a/gensim/models/nmf.py +++ b/gensim/models/nmf.py @@ -251,8 +251,6 @@ def update(self, corpus, chunks_as_numpy=False): if self.n_features is None: corpus = self._setup(corpus) - r, h = self._r, self._h - chunk_idx = 1 for _ in range(self.passes): @@ -260,14 +258,15 @@ def update(self, corpus, chunks_as_numpy=False): corpus, self.chunksize, as_numpy=chunks_as_numpy ): v = matutils.corpus2dense(chunk, len(self.id2word), len(chunk)) - h, r = self._solveproj(v, self._W, r=r, h=h, v_max=self.v_max) + self._h, self._r = self._solveproj(v, self._W, r=self._r, h=self._h, v_max=self.v_max) + h, r = self._h, self._r self._H.append(h) if self._R is not None: self._R.append(r) self.A += np.dot(h, h.T) self.B += np.dot((v - r), h.T) - self._solve_w() + self._solve_w(v) if chunk_idx % self.eval_every == 0: logger.info( @@ -286,22 +285,19 @@ def update(self, corpus, chunks_as_numpy=False): ) ) - self._r = r - self._h = h - - def _solve_w(self): + def _solve_w(self, v): eta = self._kappa / np.linalg.norm(self.A, "fro") if not self._w_error: - self._w_error = self.__w_error() + self._w_error = self.__error(v, self._h, self._r) for iter_number in range(self._w_max_iter): - # logger.info("w_error: %s" % self._w_error) + logger.info("w_error: %s" % self._w_error) self._W -= eta * (np.dot(self._W, self.A) - self.B) self.__transform() - error_ = self.__w_error() + error_ = self.__error(v, self._h, self._r) if np.abs(error_ - self._w_error) < np.abs( self._w_error * self._w_stop_condition @@ -310,10 +306,7 @@ def _solve_w(self): self._w_error = error_ - def __w_error(self): - return 0.5 * np.trace(self._W.T.dot(self._W.dot(self.A) - self.B)) - - def __h_r_error(self, v, h, r): + def __error(self, v, h, r): return 0.5 * np.linalg.norm( v - self._W.dot(h) - r, "fro" ) ** 2 + self._lambda_ * np.linalg.norm(r, 1) @@ -354,10 +347,10 @@ def _solveproj(self, v, W, h=None, r=None, v_max=None): # eta = self._kappa / np.linalg.norm(W, 'fro') ** 2 if not self._h_r_error: - self._h_r_error = self.__h_r_error(v, h, r) + self._h_r_error = self.__error(v, h, r) for iter_number in range(self._h_r_max_iter): - # logger.info("h_r_error: %s" % self._h_r_error) + logger.info("h_r_error: %s" % self._h_r_error) Wt_v_minus_r = W.T.dot(v - r) @@ -367,7 +360,7 @@ def _solveproj(self, v, W, h=None, r=None, v_max=None): r_actual = v - W.dot(h) solve_r(r, r_actual, self._lambda_, self.v_max) - error_ = self.__h_r_error(v, h, r) + error_ = self.__error(v, h, r) if np.abs(self._h_r_error - error_) < np.abs( self._h_r_error * self._h_r_stop_condition