From d76c8d86be8f4712c0769f16878a6a39131f46a5 Mon Sep 17 00:00:00 2001 From: Zhijie Cao Date: Wed, 27 Mar 2019 20:31:40 +0800 Subject: [PATCH] Fix an inference issue under 'memory' mode. --- Cell_BLAST/directi.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/Cell_BLAST/directi.py b/Cell_BLAST/directi.py index 7867c8d..560ea17 100644 --- a/Cell_BLAST/directi.py +++ b/Cell_BLAST/directi.py @@ -431,7 +431,7 @@ def inference(self, x, batch_size=4096, noisy=0, progress_bar=False, priority = "memory" if x.shape[0] > 1e5 else "speed" if priority == "speed": xrep = [x] * noisy - if scipy.sparse.issparse(x[0]): + if scipy.sparse.issparse(x): xrep = scipy.sparse.vstack(xrep) else: xrep = np.vstack(xrep) @@ -440,8 +440,9 @@ def inference(self, x, batch_size=4096, noisy=0, progress_bar=False, return np.stack(np.split(lrep, noisy), axis=1) else: # priority == "memory": return np.stack([self._fetch( - self.latent, x, batch_size, True, progress_bar, random_seed - ) for _ in range(noisy)], axis=1) + self.latent, x, batch_size, True, progress_bar, + (random_seed + i) if random_seed is not None else None + ) for i in range(noisy)], axis=1) return self._fetch(self.latent, x, batch_size, False, progress_bar) def clustering(self, x, batch_size=4096, progress_bar=True):