Skip to content

Commit

Permalink
Experimenting with loss
Browse files Browse the repository at this point in the history
  • Loading branch information
anotherbugmaster committed Jun 16, 2018
1 parent 89cc803 commit bbd3099
Showing 1 changed file with 11 additions and 18 deletions.
29 changes: 11 additions & 18 deletions gensim/models/nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,23 +251,22 @@ def update(self, corpus, chunks_as_numpy=False):
if self.n_features is None:
corpus = self._setup(corpus)

r, h = self._r, self._h

chunk_idx = 1

for _ in range(self.passes):
for chunk in utils.grouper(
corpus, self.chunksize, as_numpy=chunks_as_numpy
):
v = matutils.corpus2dense(chunk, len(self.id2word), len(chunk))
h, r = self._solveproj(v, self._W, r=r, h=h, v_max=self.v_max)
self._h, self._r = self._solveproj(v, self._W, r=self._r, h=self._h, v_max=self.v_max)
h, r = self._h, self._r
self._H.append(h)
if self._R is not None:
self._R.append(r)

self.A += np.dot(h, h.T)
self.B += np.dot((v - r), h.T)
self._solve_w()
self._solve_w(v)

if chunk_idx % self.eval_every == 0:
logger.info(
Expand All @@ -286,22 +285,19 @@ def update(self, corpus, chunks_as_numpy=False):
)
)

self._r = r
self._h = h

def _solve_w(self):
def _solve_w(self, v):
eta = self._kappa / np.linalg.norm(self.A, "fro")

if not self._w_error:
self._w_error = self.__w_error()
self._w_error = self.__error(v, self._h, self._r)

for iter_number in range(self._w_max_iter):
# logger.info("w_error: %s" % self._w_error)
logger.info("w_error: %s" % self._w_error)

self._W -= eta * (np.dot(self._W, self.A) - self.B)
self.__transform()

error_ = self.__w_error()
error_ = self.__error(v, self._h, self._r)

if np.abs(error_ - self._w_error) < np.abs(
self._w_error * self._w_stop_condition
Expand All @@ -310,10 +306,7 @@ def _solve_w(self):

self._w_error = error_

def __w_error(self):
return 0.5 * np.trace(self._W.T.dot(self._W.dot(self.A) - self.B))

def __h_r_error(self, v, h, r):
def __error(self, v, h, r):
return 0.5 * np.linalg.norm(
v - self._W.dot(h) - r, "fro"
) ** 2 + self._lambda_ * np.linalg.norm(r, 1)
Expand Down Expand Up @@ -354,10 +347,10 @@ def _solveproj(self, v, W, h=None, r=None, v_max=None):
# eta = self._kappa / np.linalg.norm(W, 'fro') ** 2

if not self._h_r_error:
self._h_r_error = self.__h_r_error(v, h, r)
self._h_r_error = self.__error(v, h, r)

for iter_number in range(self._h_r_max_iter):
# logger.info("h_r_error: %s" % self._h_r_error)
logger.info("h_r_error: %s" % self._h_r_error)

Wt_v_minus_r = W.T.dot(v - r)

Expand All @@ -367,7 +360,7 @@ def _solveproj(self, v, W, h=None, r=None, v_max=None):
r_actual = v - W.dot(h)
solve_r(r, r_actual, self._lambda_, self.v_max)

error_ = self.__h_r_error(v, h, r)
error_ = self.__error(v, h, r)

if np.abs(self._h_r_error - error_) < np.abs(
self._h_r_error * self._h_r_stop_condition
Expand Down

0 comments on commit bbd3099

Please sign in to comment.