Skip to content

Commit

Permalink
make cleanup function public
Browse files Browse the repository at this point in the history
  • Loading branch information
robbiemccorkell committed Apr 5, 2024
1 parent eaf0602 commit e49e4e9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ attributor = Attributor(
)
```

### Cleaning Up

A convenience method is provided to clean up memory used by Python and Torch. This can be useful when running the library in a cloud notebook environment:

```python
attributor.cleanup()
```

## Development

To contribute to the library, you will need to install the development requirements:
Expand Down
20 changes: 14 additions & 6 deletions attribution/attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,6 @@ def get_attributions(
attr_scores[it] = attr_scores_next_token
token_ids = torch.cat((token_ids, next_token_id.view(-1)), dim=0)

self._cleanup()

return attr_scores, token_ids

def print_attributions(
Expand All @@ -106,6 +104,20 @@ def print_attributions(
word_list, attr_scores, token_ids, generation_length
)

def cleanup(self) -> None:
"""
This function is used to free up the memory resources. It clears the GPU cache and triggers garbage collection.
Returns:
None
"""
if hasattr(torch, self.device) and hasattr(
getattr(torch, self.device), "empty_cache"
):
logging.info(f"Clearing {self.device} cache")
getattr(torch, self.device).empty_cache()
gc.collect()

def _get_input_embeddings(
self, embeddings: torch.Tensor, token_ids: torch.Tensor
) -> torch.Tensor:
Expand Down Expand Up @@ -152,10 +164,6 @@ def _get_attr_scores_next_token(
attr_scores_next_token[i] = presence_grad
return attr_scores_next_token

def _cleanup(self) -> None:
torch.cuda.empty_cache()
gc.collect()

def _validate_inputs(
self,
model: transformers.GPT2LMHeadModel,
Expand Down

0 comments on commit e49e4e9

Please sign in to comment.