Skip to content

Commit

Permalink
Deal with whitespace tokens in nth-nearest perturbation
Browse files Browse the repository at this point in the history
  • Loading branch information
Acusick1 committed Jul 25, 2024
1 parent 1f36de7 commit 07f52e8
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions attribution/token_perturbation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,17 @@ def get_replacement_units(self, units_to_replace: list[Unit]) -> list[Unit]:
for unit in units_to_replace:
replacement_tokens = []
for token in unit:
token_id = self.tokenizer.encode(token, add_special_tokens=False)[0]
# Stripping whitespace token if present as it often results in a completely different replacement token
stripped_token = token.strip("Ġ")
token_id = self.tokenizer.encode(stripped_token, add_special_tokens=False)[0]
replacement_token_id = self.get_replacement_token(token_id)
replacement_token = self.tokenizer._convert_id_to_token(replacement_token_id)
replacement_tokens.append(f"Ġ{replacement_token}")

# Re-add whitespace prefix if necessary
if token.startswith("Ġ") and not replacement_token.startswith("Ġ"):
replacement_token = f"Ġ{replacement_token}"

replacement_tokens.append(replacement_token)

replacement_units.append(replacement_tokens)

Expand Down

0 comments on commit 07f52e8

Please sign in to comment.