Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patched several memory leaks in the mask module. #2

Merged
merged 1 commit into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/seismicrna/core/batch/accum.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@ def accumulate(batches: Iterable[MutsBatch],
fits_per_read_per_batch[-1].loc[:, column] = fpr.values
if info_per_read_per_batch is not None:
info_per_read_per_batch[-1].loc[:, column] = ipr.values

# Clear batch cache.
batch.clear_cache()
# Clear pattern caches.
for pattern in patterns.values():
pattern.clear_cache()

def get_data_per_read(data_per_read_per_batch: pd.DataFrame | None):
if data_per_read_per_batch is not None:
if data_per_read_per_batch:
Expand Down
9 changes: 9 additions & 0 deletions src/seismicrna/core/batch/muts.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,15 @@ def nonprox_muts(self, refseq: DNA, pattern: RelPattern, min_gap: int):
# which means that they have proximal mutations.
nonprox &= read_counts <= 1
return self.read_nums[nonprox]

def clear_cache(self):
for method in [self.pos_index, self.count_base_types,
self.coverage_matrix, self.cover_per_pos,
self.cover_per_read, self.rels_per_pos,
self.rels_per_read, self.reads_per_pos,
self.count_per_pos, self.count_per_read,
self.nonprox_muts]:
method.cache_clear()

def iter_reads(self, refseq: DNA, pattern: RelPattern):
""" Yield the 5'/3' end/middle positions and the positions that
Expand Down
4 changes: 4 additions & 0 deletions src/seismicrna/core/io/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,10 @@ def load_batch(self, batch: int):
return self.get_data_type().load(self.get_batch_path(batch),
self.report_checksum(batch))

def clear_cache(self):
self.get_batch_path.cache_clear()
self.report_checksum.cache_clear()

def _iter_batches(self):
for batch in self.batch_nums:
yield self.load_batch(batch)
Expand Down
9 changes: 9 additions & 0 deletions src/seismicrna/core/rel/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,9 @@ def intersect(self, other: HalfRelPattern):
""" Intersect the HalfRelPattern with another. """
return self.__class__(*(set(self.codes) & set(other.codes)))

def clear_cache(self):
self.fits.cache_clear()

def __str__(self):
return f"{type(self).__name__} {self.to_report_format()}"

Expand Down Expand Up @@ -302,6 +305,12 @@ def intersect(self, other: RelPattern | None, invert: bool = False):
nos = self.nos
return self.__class__(nos, yes) if invert else self.__class__(yes, nos)

def clear_cache(self):
self.fits.cache_clear()
self.intersect.cache_clear()
self.yes.clear_cache()
self.nos.clear_cache()

def __str__(self):
return f"{type(self).__name__} ++ {self.yes} -- {self.nos}"

Expand Down
6 changes: 5 additions & 1 deletion src/seismicrna/core/seq/section.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,8 @@ def add_mask(self,
p = np.setdiff1d(self.range_int, p, assume_unique=True)
# Record the positions that have not already been masked.
self._masks[name] = np.setdiff1d(p, self.masked_int, assume_unique=True)
logger.debug(f"Added mask {repr(name)} ({self._masks[name]}) to {self}")
# Do not log self._masks[name] due to memory leak.
logger.debug(f"Added mask {repr(name)} to {self}")

def _find_gu(self) -> np.ndarray:
""" Array of each position whose base is neither A nor C. """
Expand Down Expand Up @@ -465,6 +466,9 @@ def subsection(self,
and name is None)
else name))

def clear_cache(self):
self.seq.clear_cache()

def __str__(self):
return f"Section {self.ref_sect} ({self.hyphen}) {self.mask_names}"

Expand Down
3 changes: 3 additions & 0 deletions src/seismicrna/core/seq/xna.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def compress(self):
""" Compress the sequence. """
return CompressedSeq(self)

def clear_cache(self):
self.to_array.cache_clear()

def __str__(self):
return self._seq

Expand Down
9 changes: 7 additions & 2 deletions src/seismicrna/mask/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,13 @@ def create_report(self, began: datetime, ended: datetime):
min_ninfo_pos=self.min_ninfo_pos,
max_fmut_pos=self.max_fmut_pos,
n_pos_init=self.section.length,
n_pos_gu=self.pos_gu.size,
n_pos_gu=self.pos_gu.size if self.exclude_gu else 0, # Not sure how to handle this properly.
n_pos_polya=self.pos_polya.size,
n_pos_user=self.pos_user.size,
n_pos_min_ninfo=self.pos_min_ninfo.size,
n_pos_max_fmut=self.pos_max_fmut.size,
n_pos_kept=self.pos_kept.size,
pos_gu=self.pos_gu,
pos_gu=self.pos_gu if self.exclude_gu else np.array([], dtype=int), # Not sure how to handle this properly.
pos_polya=self.pos_polya,
pos_user=self.pos_user,
pos_min_ninfo=self.pos_min_ninfo,
Expand Down Expand Up @@ -361,6 +361,11 @@ def mask_section(dataset: RelateLoader,
ended = datetime.now()
report = masker.create_report(began, ended)
report.save(dataset.top, overwrite=True)
# Clear section cache.
section.clear_cache()
# Clear dataset caches.
dataset.clear_cache()

else:
logger.warning(f"File exists: {report_file}")
return report_file
Expand Down