
Commit

Fix bug when printing attr matrix for cosine method, and parametrize if input should be perturbed by token or by word
Sebastian Sosa committed May 7, 2024
1 parent b2d7cbf commit 0bb8b2e
Showing 3 changed files with 841 additions and 1,805 deletions.
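
The "by token or by word" distinction in the commit message sets the granularity of the units that get perturbed. Below is a minimal sketch of that distinction under assumed names: the perturbation_units helper and the GPT-2 tokenizer are illustrative stand-ins, not this repository's implementation.

# Illustrative sketch only: contrasts perturbing a prompt word by word with
# perturbing it token by token. Function name and tokenizer choice are assumptions.
from transformers import AutoTokenizer

def perturbation_units(text: str, tokenizer, perturb_word_wise: bool) -> list[str]:
    """Return the units that would be perturbed one at a time."""
    if perturb_word_wise:
        return text.split()          # one unit per whitespace-separated word
    return tokenizer.tokenize(text)  # one unit per model token

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(perturbation_units("The quick brown fox", tokenizer, perturb_word_wise=True))
# ['The', 'quick', 'brown', 'fox']
print(perturbation_units("The quick brown fox", tokenizer, perturb_word_wise=False))
# ['The', 'Ġquick', 'Ġbrown', 'Ġfox']
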
attribution/attribution_metrics.py (26 changes: 16 additions & 10 deletions)
@@ -109,34 +109,40 @@ def cosine_similarity_attribution(
     tokenizer: PreTrainedTokenizer,
 ) -> Tuple[float, np.ndarray]:
     # Extract embeddings
-    initial_sentence_emb, initial_token_embs = get_sentence_embeddings(
+    initial_output_sentence_emb, initial_output_token_embs = get_sentence_embeddings(
         original_output_choice.message.content, model, tokenizer
     )
-    perturbed_sentence_emb, perturbed_token_embs = get_sentence_embeddings(
-        perturbed_output_choice.message.content, model, tokenizer
+    perturbed_output_sentence_emb, perturbed_output_token_embs = (
+        get_sentence_embeddings(
+            perturbed_output_choice.message.content, model, tokenizer
+        )
     )

     # Reshape embeddings
-    initial_sentence_emb = initial_sentence_emb.reshape(1, -1)
-    perturbed_sentence_emb = perturbed_sentence_emb.reshape(1, -1)
+    initial_output_sentence_emb = initial_output_sentence_emb.reshape(1, -1)
+    perturbed_output_sentence_emb = perturbed_output_sentence_emb.reshape(1, -1)

     # Calculate similarities
     self_similarity = float(
-        cosine_similarity(initial_sentence_emb, initial_sentence_emb)
+        cosine_similarity(initial_output_sentence_emb, initial_output_sentence_emb)
     )
     sentence_similarity = float(
-        cosine_similarity(initial_sentence_emb, perturbed_sentence_emb)
+        cosine_similarity(initial_output_sentence_emb, perturbed_output_sentence_emb)
     )

     # Calculate token similarities for shared length
-    shared_length = min(initial_token_embs.shape[0], perturbed_token_embs.shape[0])
+    shared_length = min(
+        initial_output_token_embs.shape[0], perturbed_output_token_embs.shape[0]
+    )
     token_similarities_shared = cosine_similarity(
-        initial_token_embs[:shared_length], perturbed_token_embs[:shared_length]
+        initial_output_token_embs[:shared_length],
+        perturbed_output_token_embs[:shared_length],
     ).diagonal()

     # Pad token similarities to match initial token embeddings shape
     token_similarities = np.pad(
-        token_similarities_shared, (0, initial_token_embs.shape[0] - shared_length)
+        token_similarities_shared,
+        (0, initial_output_token_embs.shape[0] - shared_length),
     )

     # Return difference in sentence similarity and token similarities
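
The hunk above only renames variables and reflows calls; the computation is unchanged: embed both outputs, compare the sentence embeddings, compare token embeddings pairwise over the shared length, and zero-pad to the original token count. A self-contained sketch under assumed inputs follows (random arrays stand in for the output of get_sentence_embeddings; this is not the repository's code).

# Minimal sketch of the similarity computation in the hunk above (not the repo's code).
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

rng = np.random.default_rng(0)
initial_output_token_embs = rng.normal(size=(7, 768))    # 7 tokens in the original output
perturbed_output_token_embs = rng.normal(size=(5, 768))  # 5 tokens in the perturbed output
initial_output_sentence_emb = initial_output_token_embs.mean(axis=0).reshape(1, -1)
perturbed_output_sentence_emb = perturbed_output_token_embs.mean(axis=0).reshape(1, -1)

# Sentence-level similarity between original and perturbed outputs.
sentence_similarity = float(
    cosine_similarity(initial_output_sentence_emb, perturbed_output_sentence_emb)[0, 0]
)

# Token-level similarities over the shared prefix, zero-padded so the result
# always has one entry per original output token.
shared_length = min(
    initial_output_token_embs.shape[0], perturbed_output_token_embs.shape[0]
)
token_similarities = np.pad(
    cosine_similarity(
        initial_output_token_embs[:shared_length],
        perturbed_output_token_embs[:shared_length],
    ).diagonal(),
    (0, initial_output_token_embs.shape[0] - shared_length),
)
print(sentence_similarity, token_similarities.shape)  # a float and (7,)
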
attribution/experiment_logger.py (12 changes: 11 additions & 1 deletion)
@@ -15,6 +15,7 @@ def __init__(self, experiment_id=0):
                 "original_input",
                 "original_output",
                 "perturbation_strategy",
+                "perturb_word_wise",
                 "duration",
             ]
         )
@@ -51,7 +52,11 @@ def __init__(self, experiment_id=0):
         )

     def start_experiment(
-        self, original_input: str, original_output: str, perturbation_strategy: str
+        self,
+        original_input: str,
+        original_output: str,
+        perturbation_strategy: str,
+        perturb_word_wise: bool,
     ):
         self.experiment_id += 1
         self.experiment_start_time = time.time()
@@ -60,6 +65,7 @@ def __init__(self, experiment_id=0):
             "original_input": original_input,
             "original_output": original_output,
             "perturbation_strategy": perturbation_strategy,
+            "perturb_word_wise": perturb_word_wise,
             "duration": None,
         }

@@ -140,11 +146,15 @@ def print_sentence_attribution(self):
             perturbation_strategy = self.df_experiments.loc[
                 self.df_experiments["exp_id"] == exp_id, "perturbation_strategy"
             ].values[0]
+            perturb_word_wise = self.df_experiments.loc[
+                self.df_experiments["exp_id"] == exp_id, "perturb_word_wise"
+            ].values[0]

             sentence_data = {
                 "exp_id": exp_id,
                 "attribution_strategy": attr_strat,
                 "perturbation_strategy": perturbation_strategy,
+                "perturb_word_wise": perturb_word_wise,
             }
             sentence_data.update(
                 {f"token_{i+1}": token_attr for i, token_attr in enumerate(token_attrs)}
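
A hedged usage sketch of the widened start_experiment signature follows; the class name ExperimentLogger and the example argument values are assumptions based only on what this diff shows.

# Hypothetical usage; only start_experiment's new perturb_word_wise argument is
# confirmed by this diff. The class name and example values are assumed.
from attribution.experiment_logger import ExperimentLogger

logger = ExperimentLogger()
logger.start_experiment(
    original_input="The quick brown fox",
    original_output="jumps over the lazy dog",
    perturbation_strategy="fixed",  # assumed example value
    perturb_word_wise=True,         # new flag: perturb word by word instead of token by token
)
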