From e49e4e9ec2d094a283a8b7b1afdecedb3e23144f Mon Sep 17 00:00:00 2001
From: Robbie McCorkell
Date: Fri, 5 Apr 2024 14:32:51 +0100
Subject: [PATCH] make cleanup function public

---
 README.md                  |  8 ++++++++
 attribution/attribution.py | 20 ++++++++++++++------
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 6e7ee1b..126da51 100644
--- a/README.md
+++ b/README.md
@@ -101,6 +101,14 @@ attributor = Attributor(
 )
 ```
 
+### Cleaning Up
+
+A convenience method is provided to clean up memory used by Python and Torch. This can be useful when running the library in a cloud notebook environment:
+
+```python
+attributor.cleanup()
+```
+
 ## Development
 
 To contribute to the library, you will need to install the development requirements:
diff --git a/attribution/attribution.py b/attribution/attribution.py
index e9cf246..8f49238 100644
--- a/attribution/attribution.py
+++ b/attribution/attribution.py
@@ -80,8 +80,6 @@ def get_attributions(
             attr_scores[it] = attr_scores_next_token
             token_ids = torch.cat((token_ids, next_token_id.view(-1)), dim=0)
 
-        self._cleanup()
-
         return attr_scores, token_ids
 
     def print_attributions(
@@ -106,6 +104,20 @@ def print_attributions(
             word_list, attr_scores, token_ids, generation_length
         )
 
+    def cleanup(self) -> None:
+        """
+        This function is used to free up the memory resources. It clears the GPU cache and triggers garbage collection.
+
+        Returns:
+            None
+        """
+        if hasattr(torch, self.device) and hasattr(
+            getattr(torch, self.device), "empty_cache"
+        ):
+            logging.info(f"Clearing {self.device} cache")
+            getattr(torch, self.device).empty_cache()
+        gc.collect()
+
     def _get_input_embeddings(
         self, embeddings: torch.Tensor, token_ids: torch.Tensor
     ) -> torch.Tensor:
@@ -152,10 +164,6 @@ def _get_attr_scores_next_token(
             attr_scores_next_token[i] = presence_grad
         return attr_scores_next_token
 
-    def _cleanup(self) -> None:
-        torch.cuda.empty_cache()
-        gc.collect()
-
     def _validate_inputs(
         self,
         model: transformers.GPT2LMHeadModel,