diff --git a/src/seismicrna/core/batch/accum.py b/src/seismicrna/core/batch/accum.py index 53f3a384..cb33bdd4 100644 --- a/src/seismicrna/core/batch/accum.py +++ b/src/seismicrna/core/batch/accum.py @@ -70,7 +70,12 @@ def accumulate(batches: Iterable[MutsBatch], fits_per_read_per_batch[-1].loc[:, column] = fpr.values if info_per_read_per_batch is not None: info_per_read_per_batch[-1].loc[:, column] = ipr.values - + # Clear batch cache. + batch.clear_cache() + # Clear pattern caches. + for pattern in patterns.values(): + pattern.clear_cache() + def get_data_per_read(data_per_read_per_batch: pd.DataFrame | None): if data_per_read_per_batch is not None: if data_per_read_per_batch: diff --git a/src/seismicrna/core/batch/muts.py b/src/seismicrna/core/batch/muts.py index 6aa5d55a..aefaeee8 100644 --- a/src/seismicrna/core/batch/muts.py +++ b/src/seismicrna/core/batch/muts.py @@ -192,6 +192,15 @@ def nonprox_muts(self, refseq: DNA, pattern: RelPattern, min_gap: int): # which means that they have proximal mutations. nonprox &= read_counts <= 1 return self.read_nums[nonprox] + + def clear_cache(self): + for method in [self.pos_index, self.count_base_types, + self.coverage_matrix, self.cover_per_pos, + self.cover_per_read, self.rels_per_pos, + self.rels_per_read, self.reads_per_pos, + self.count_per_pos, self.count_per_read, + self.nonprox_muts]: + method.cache_clear() def iter_reads(self, refseq: DNA, pattern: RelPattern): """ Yield the 5'/3' end/middle positions and the positions that diff --git a/src/seismicrna/core/io/data.py b/src/seismicrna/core/io/data.py index 52eb030e..bda878d3 100644 --- a/src/seismicrna/core/io/data.py +++ b/src/seismicrna/core/io/data.py @@ -271,6 +271,10 @@ def load_batch(self, batch: int): return self.get_data_type().load(self.get_batch_path(batch), self.report_checksum(batch)) + def clear_cache(self): + self.get_batch_path.cache_clear() + self.report_checksum.cache_clear() + def _iter_batches(self): for batch in self.batch_nums: yield self.load_batch(batch) diff --git a/src/seismicrna/core/rel/pattern.py b/src/seismicrna/core/rel/pattern.py index c3c232e5..5cf2a679 100644 --- a/src/seismicrna/core/rel/pattern.py +++ b/src/seismicrna/core/rel/pattern.py @@ -248,6 +248,9 @@ def intersect(self, other: HalfRelPattern): """ Intersect the HalfRelPattern with another. """ return self.__class__(*(set(self.codes) & set(other.codes))) + def clear_cache(self): + self.fits.cache_clear() + def __str__(self): return f"{type(self).__name__} {self.to_report_format()}" @@ -302,6 +305,12 @@ def intersect(self, other: RelPattern | None, invert: bool = False): nos = self.nos return self.__class__(nos, yes) if invert else self.__class__(yes, nos) + def clear_cache(self): + self.fits.cache_clear() + self.intersect.cache_clear() + self.yes.clear_cache() + self.nos.clear_cache() + def __str__(self): return f"{type(self).__name__} ++ {self.yes} -- {self.nos}" diff --git a/src/seismicrna/core/seq/section.py b/src/seismicrna/core/seq/section.py index b7bdfb08..73ce5bde 100644 --- a/src/seismicrna/core/seq/section.py +++ b/src/seismicrna/core/seq/section.py @@ -410,7 +410,8 @@ def add_mask(self, p = np.setdiff1d(self.range_int, p, assume_unique=True) # Record the positions that have not already been masked. self._masks[name] = np.setdiff1d(p, self.masked_int, assume_unique=True) - logger.debug(f"Added mask {repr(name)} ({self._masks[name]}) to {self}") + # Do not log self._masks[name] due to memory leak. + logger.debug(f"Added mask {repr(name)} to {self}") def _find_gu(self) -> np.ndarray: """ Array of each position whose base is neither A nor C. """ @@ -465,6 +466,9 @@ def subsection(self, and name is None) else name)) + def clear_cache(self): + self.seq.clear_cache() + def __str__(self): return f"Section {self.ref_sect} ({self.hyphen}) {self.mask_names}" diff --git a/src/seismicrna/core/seq/xna.py b/src/seismicrna/core/seq/xna.py index fe2ec37b..3b3d00dc 100644 --- a/src/seismicrna/core/seq/xna.py +++ b/src/seismicrna/core/seq/xna.py @@ -160,6 +160,9 @@ def compress(self): """ Compress the sequence. """ return CompressedSeq(self) + def clear_cache(self): + self.to_array.cache_clear() + def __str__(self): return self._seq diff --git a/src/seismicrna/mask/write.py b/src/seismicrna/mask/write.py index 3b5b1e71..2a5955b9 100644 --- a/src/seismicrna/mask/write.py +++ b/src/seismicrna/mask/write.py @@ -312,13 +312,13 @@ def create_report(self, began: datetime, ended: datetime): min_ninfo_pos=self.min_ninfo_pos, max_fmut_pos=self.max_fmut_pos, n_pos_init=self.section.length, - n_pos_gu=self.pos_gu.size, + n_pos_gu=self.pos_gu.size if self.exclude_gu else 0, # Not sure how to handle this properly. n_pos_polya=self.pos_polya.size, n_pos_user=self.pos_user.size, n_pos_min_ninfo=self.pos_min_ninfo.size, n_pos_max_fmut=self.pos_max_fmut.size, n_pos_kept=self.pos_kept.size, - pos_gu=self.pos_gu, + pos_gu=self.pos_gu if self.exclude_gu else np.array([], dtype=int), # Not sure how to handle this properly. pos_polya=self.pos_polya, pos_user=self.pos_user, pos_min_ninfo=self.pos_min_ninfo, @@ -361,6 +361,11 @@ def mask_section(dataset: RelateLoader, ended = datetime.now() report = masker.create_report(began, ended) report.save(dataset.top, overwrite=True) + # Clear section cache. + section.clear_cache() + # Clear dataset caches. + dataset.clear_cache() + else: logger.warning(f"File exists: {report_file}") return report_file