From 0bb8b2e732ec82e199fde89c2a4f1c057979724a Mon Sep 17 00:00:00 2001 From: Sebastian Sosa Date: Tue, 7 May 2024 16:38:16 -0300 Subject: [PATCH] Fix bug when printing attr matrix for cosine method, and parametrize if input shoud be perturbed by token or by word --- attribution/attribution_metrics.py | 26 +- attribution/experiment_logger.py | 12 +- research/api_llm_attribution.ipynb | 2608 +++++++++------------------- 3 files changed, 841 insertions(+), 1805 deletions(-) diff --git a/attribution/attribution_metrics.py b/attribution/attribution_metrics.py index 2344094..b42eebf 100644 --- a/attribution/attribution_metrics.py +++ b/attribution/attribution_metrics.py @@ -109,34 +109,40 @@ def cosine_similarity_attribution( tokenizer: PreTrainedTokenizer, ) -> Tuple[float, np.ndarray]: # Extract embeddings - initial_sentence_emb, initial_token_embs = get_sentence_embeddings( + initial_output_sentence_emb, initial_output_token_embs = get_sentence_embeddings( original_output_choice.message.content, model, tokenizer ) - perturbed_sentence_emb, perturbed_token_embs = get_sentence_embeddings( - perturbed_output_choice.message.content, model, tokenizer + perturbed_output_sentence_emb, perturbed_output_token_embs = ( + get_sentence_embeddings( + perturbed_output_choice.message.content, model, tokenizer + ) ) # Reshape embeddings - initial_sentence_emb = initial_sentence_emb.reshape(1, -1) - perturbed_sentence_emb = perturbed_sentence_emb.reshape(1, -1) + initial_output_sentence_emb = initial_output_sentence_emb.reshape(1, -1) + perturbed_output_sentence_emb = perturbed_output_sentence_emb.reshape(1, -1) # Calculate similarities self_similarity = float( - cosine_similarity(initial_sentence_emb, initial_sentence_emb) + cosine_similarity(initial_output_sentence_emb, initial_output_sentence_emb) ) sentence_similarity = float( - cosine_similarity(initial_sentence_emb, perturbed_sentence_emb) + cosine_similarity(initial_output_sentence_emb, perturbed_output_sentence_emb) ) # Calculate token similarities for shared length - shared_length = min(initial_token_embs.shape[0], perturbed_token_embs.shape[0]) + shared_length = min( + initial_output_token_embs.shape[0], perturbed_output_token_embs.shape[0] + ) token_similarities_shared = cosine_similarity( - initial_token_embs[:shared_length], perturbed_token_embs[:shared_length] + initial_output_token_embs[:shared_length], + perturbed_output_token_embs[:shared_length], ).diagonal() # Pad token similarities to match initial token embeddings shape token_similarities = np.pad( - token_similarities_shared, (0, initial_token_embs.shape[0] - shared_length) + token_similarities_shared, + (0, initial_output_token_embs.shape[0] - shared_length), ) # Return difference in sentence similarity and token similarities diff --git a/attribution/experiment_logger.py b/attribution/experiment_logger.py index d04b2ac..1265338 100644 --- a/attribution/experiment_logger.py +++ b/attribution/experiment_logger.py @@ -15,6 +15,7 @@ def __init__(self, experiment_id=0): "original_input", "original_output", "perturbation_strategy", + "perturb_word_wise", "duration", ] ) @@ -51,7 +52,11 @@ def __init__(self, experiment_id=0): ) def start_experiment( - self, original_input: str, original_output: str, perturbation_strategy: str + self, + original_input: str, + original_output: str, + perturbation_strategy: str, + perturb_word_wise: bool, ): self.experiment_id += 1 self.experiment_start_time = time.time() @@ -60,6 +65,7 @@ def start_experiment( "original_input": original_input, "original_output": original_output, "perturbation_strategy": perturbation_strategy, + "perturb_word_wise": perturb_word_wise, "duration": None, } @@ -140,11 +146,15 @@ def print_sentence_attribution(self): perturbation_strategy = self.df_experiments.loc[ self.df_experiments["exp_id"] == exp_id, "perturbation_strategy" ].values[0] + perturb_word_wise = self.df_experiments.loc[ + self.df_experiments["exp_id"] == exp_id, "perturb_word_wise" + ].values[0] sentence_data = { "exp_id": exp_id, "attribution_strategy": attr_strat, "perturbation_strategy": perturbation_strategy, + "perturb_word_wise": perturb_word_wise, } sentence_data.update( {f"token_{i+1}": token_attr for i, token_attr in enumerate(token_attrs)} diff --git a/research/api_llm_attribution.ipynb b/research/api_llm_attribution.ipynb index ca10f98..91ff9ba 100644 --- a/research/api_llm_attribution.ipynb +++ b/research/api_llm_attribution.ipynb @@ -2,9 +2,18 @@ "cells": [ { "cell_type": "code", - "execution_count": 8, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/sebastian/Projects/llm-attribution/.venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import os\n", "import timeit\n", @@ -41,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -51,12 +60,12 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "model = GPT2LMHeadModel.from_pretrained(\"gpt2\") # or any other checkpoint\n", - "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n", + "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\", add_prefix_space=True)\n", "\n", "word_token_embeddings = model.transformer.wte.weight.detach().numpy()\n", "position_embeddings = model.transformer.wpe.weight.detach().numpy()\n", @@ -67,7 +76,59 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 70, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenizing the entire input text:\n", + "[['Ġthe'], ['Ġcamel'], ['Ġcase'], ['Ġomnip'], ['otent']]\n", + "[[262], [41021], [1339], [40046], [33715]]\n", + "\n", + "Tokenizing word by word:\n", + "[['Ġthe'], ['Ġcamel'], ['Ġcase'], ['Ġomnip', 'otent']]\n", + "[[262], [41021], [1339], [40046, 33715]]\n", + "\n", + "[262, 41021, 1339]\n", + "[[262], [41021], [1339]]\n", + "[262, 41021, 1339]\n" + ] + } + ], + "source": [ + "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\", add_prefix_space=True)\n", + "\n", + "# Test input text\n", + "input_text = \"the camel case omnipotent\"\n", + "\n", + "# Tokenize the entire input text\n", + "input_tokens = [[token] for token in tokenizer.tokenize(input_text)]\n", + "input_token_ids = [\n", + " [token_id] for token_id in tokenizer.encode(input_text, add_special_tokens=False)\n", + "]\n", + "print(\"Tokenizing the entire input text:\")\n", + "print(input_tokens)\n", + "print(input_token_ids)\n", + "\n", + "# Tokenize word by word\n", + "words = input_text.split()\n", + "word_tokens = [tokenizer.tokenize(word) for word in words]\n", + "word_token_ids = [tokenizer.encode(word, add_special_tokens=False) for word in words]\n", + "print(\"\\nTokenizing word by word:\")\n", + "print(word_tokens)\n", + "print(word_token_ids)\n", + "\n", + "print()\n", + "print([token_id for sublist in input_token_ids[:3] for token_id in sublist])\n", + "print([sublist for sublist in input_token_ids[:3]])\n", + "print([token_id for sublist in input_token_ids[:3] for token_id in sublist])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -83,8 +144,8 @@ " return response.choices[0]\n", "\n", "\n", - "def calculate_token_importance_in_sequence(\n", - " input_sequence: str,\n", + "def calculate_token_importance(\n", + " input_text: str,\n", " model: PreTrainedModel,\n", " tokenizer: PreTrainedTokenizer,\n", " perturbation_strategy: str = \"fixed\",\n", @@ -94,55 +155,81 @@ " \"token_displacement\",\n", " ],\n", " logger: ExperimentLogger = None,\n", + " perturb_word_wise: bool = False,\n", "):\n", " timestamp = time()\n", - " tokens = tokenizer.tokenize(input_sequence)\n", - " token_ids = tokenizer.encode(input_sequence, add_special_tokens=False)\n", - " original_output = get_model_output(input_sequence)\n", + " original_output = get_model_output(input_text)\n", " print(f\"Chat Completion - Original: {round(time() - timestamp, 2)}s\")\n", "\n", " if logger:\n", " logger.start_experiment(\n", - " input_sequence, original_output.message.content, perturbation_strategy\n", + " input_text,\n", + " original_output.message.content,\n", + " perturbation_strategy,\n", + " perturb_word_wise,\n", " )\n", "\n", " exp_timestamp = time()\n", - " words = input_sequence.split()\n", - " token_index = 0\n", - " for word in words:\n", - " word_tokens = tokenizer.tokenize(word)\n", - " word_token_ids = tokenizer.encode(word, add_special_tokens=False)\n", + "\n", + " # A unit is either a word or a single token\n", + " unit_offset = 0\n", + " if perturb_word_wise:\n", + " words = input_text.split()\n", + " tokens_per_unit = [tokenizer.tokenize(word) for word in words]\n", + " token_ids_per_unit = [\n", + " tokenizer.encode(word, add_special_tokens=False) for word in words\n", + " ]\n", + " else:\n", + " tokens_per_unit = [[token] for token in tokenizer.tokenize(input_text)]\n", + " token_ids_per_unit = [\n", + " [token_id]\n", + " for token_id in tokenizer.encode(input_text, add_special_tokens=False)\n", + " ]\n", + "\n", + " for i_unit, unit_tokens in enumerate(tokens_per_unit):\n", " start_word_time = time()\n", " replacement_token_ids = [\n", " get_replacement_token(\n", - " word_token_ids[i],\n", + " token_id,\n", " perturbation_strategy,\n", " word_token_embeddings,\n", " tokenizer,\n", " )\n", - " for i in range(len(word_tokens))\n", + " for token_id in token_ids_per_unit[i_unit]\n", " ]\n", " print(\n", - " f\"\\nReplaced word '{word}': {round(time() - start_word_time, 2)}s - get_replacement_token()\"\n", + " f\"\\nReplaced word '{''.join(unit_tokens)}': {round(time() - start_word_time, 2)}s - get_replacement_token()\"\n", " )\n", "\n", " # Replace the current word with the new tokens\n", + " left_token_ids = [\n", + " token_id\n", + " for unit_token_ids in token_ids_per_unit[:i_unit]\n", + " for token_id in unit_token_ids\n", + " ]\n", + " right_token_ids = [\n", + " token_id\n", + " for unit_token_ids in token_ids_per_unit[i_unit + 1 :]\n", + " for token_id in unit_token_ids\n", + " ]\n", " perturbed_input = tokenizer.decode(\n", - " token_ids[:token_index]\n", - " + replacement_token_ids\n", - " + token_ids[token_index + len(word_tokens) :]\n", + " left_token_ids + replacement_token_ids + right_token_ids\n", " )\n", "\n", " # Get the output logprobs for the perturbed input\n", " timestamp = time()\n", - " print('Original: ',input_sequence)\n", - " print('Perturbed: ',perturbed_input)\n", + " print(\"Original: \", input_text)\n", + " print(\"Perturbed: \", perturbed_input)\n", " perturbed_output = get_model_output(perturbed_input)\n", " print(f\"Chat Completion - Perturbed: {round(time() - timestamp, 2)}s\")\n", "\n", " timestamp = time()\n", " for attribution_strategy in attribution_strategies:\n", - " attributed_tokens = tokens\n", + " attributed_tokens = [\n", + " token_logprob.token\n", + " for token_logprob in original_output.logprobs.content\n", + " ]\n", + " print(attribution_strategy, \"attributed_tokens\", attributed_tokens)\n", " if attribution_strategy == \"cosine\":\n", " cosine_timestamp = time()\n", " sentence_attr, token_attributions = cosine_similarity_attribution(\n", @@ -172,35 +259,35 @@ "\n", " if logger:\n", " start_logging = time()\n", - " for i in range(len(word_tokens)):\n", + " for i, unit_token in enumerate(unit_tokens):\n", " logger.log_input_token_attribution(\n", " attribution_strategy,\n", - " token_index + i,\n", - " word_tokens[i],\n", + " unit_offset + i,\n", + " unit_token,\n", " float(sentence_attr),\n", " )\n", " for j, attr_score in enumerate(token_attributions):\n", " logger.log_token_attribution_matrix(\n", " attribution_strategy,\n", - " token_index + i,\n", + " unit_offset + i,\n", " j,\n", " attributed_tokens[j],\n", " attr_score.squeeze(),\n", " )\n", " end_logging = time()\n", " time_all_attrs = time() - timestamp\n", - " print(f\"Attributions computation: {time_all_attrs}s\")\n", - " print(f\"- Cosine Attr: {round(cosine_timestamp_end - cosine_timestamp, 2)}s\")\n", - " print(\n", - " f\"- Prob Diff Attr: {round(prob_diff_timestamp_end - prob_diff_timestamp, 2)}s\"\n", - " )\n", - " print(\n", - " f\"- Token Displacement Attr: {round(token_displacement_timestamp_end - token_displacement_timestamp, 2)}s\"\n", - " )\n", - " print(f\"- Attr Logging: {round(end_logging - start_logging, 2)}s\")\n", - " print(f\"Total for word '{word}': {round(time() - start_word_time, 2)}s\")\n", + " # print(f\"Attributions computation: {time_all_attrs}s\")\n", + " # print(f\"- Cosine Attr: {round(cosine_timestamp_end - cosine_timestamp, 2)}s\")\n", + " # print(\n", + " # f\"- Prob Diff Attr: {round(prob_diff_timestamp_end - prob_diff_timestamp, 2)}s\"\n", + " # )\n", + " # print(\n", + " # f\"- Token Displacement Attr: {round(token_displacement_timestamp_end - token_displacement_timestamp, 2)}s\"\n", + " # )\n", + " # print(f\"- Attr Logging: {round(end_logging - start_logging, 2)}s\")\n", + " # print(f\"Total for word '{word}': {round(time() - start_word_time, 2)}s\")\n", "\n", - " token_index += len(word_tokens)\n", + " unit_offset += len(unit_tokens)\n", "\n", " print(f\"\\n\\nExp Total: {time() - exp_timestamp}s\\n\\n\")\n", "\n", @@ -209,7 +296,7 @@ " i,\n", " tokenizer.decode(replacement_token_ids)[0],\n", " perturbation_strategy,\n", - " input_sequence,\n", + " input_text,\n", " original_output.message.content,\n", " perturbed_input,\n", " perturbed_output.message.content,\n", @@ -221,127 +308,97 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Chat Completion - Original: 0.73s\n", + "Chat Completion - Original: 1.26s\n", "\n", - "Replaced word 'The': 0.06s - get_replacement_token()\n", + "Replaced word 'ĠThe': 0.26s - get_replacement_token()\n", "Original: The clock shows 9:47 PM. How many minutes 'til 10?\n", - "Perturbed: exp clock shows 9:47 PM. How many minutes 'til 10?\n", - "Chat Completion - Perturbed: 0.73s\n", - "Attributions computation: 0.009429931640625s\n", - "- Cosine Attr: 0.0s\n", - "- Prob Diff Attr: 0.0s\n", - "- Token Displacement Attr: 0.0s\n", - "- Attr Logging: 0.0s\n", - "Total for word 'The': 0.8s\n", + "Perturbed: Limit clock shows 9:47 PM. How many minutes 'til 10?\n", + "Chat Completion - Perturbed: 0.65s\n", + "cosine attributed_tokens ['13', ' minutes', '.']\n", + "prob_diff attributed_tokens ['13', ' minutes', '.']\n", + "token_displacement attributed_tokens ['13', ' minutes', '.']\n", "\n", - "Replaced word 'clock': 0.05s - get_replacement_token()\n", + "Replaced word 'Ġclock': 0.24s - get_replacement_token()\n", "Original: The clock shows 9:47 PM. How many minutes 'til 10?\n", - "Perturbed: Thework shows 9:47 PM. How many minutes 'til 10?\n", - "Chat Completion - Perturbed: 0.91s\n", - "Attributions computation: 0.008873939514160156s\n", - "- Cosine Attr: 0.0s\n", - "- Prob Diff Attr: 0.0s\n", - "- Token Displacement Attr: 0.0s\n", - "- Attr Logging: 0.0s\n", - "Total for word 'clock': 0.97s\n", + "Perturbed: The Penny shows 9:47 PM. How many minutes 'til 10?\n", + "Chat Completion - Perturbed: 4.96s\n", + "cosine attributed_tokens ['13', ' minutes', '.']\n", + "prob_diff attributed_tokens ['13', ' minutes', '.']\n", + "token_displacement attributed_tokens ['13', ' minutes', '.']\n", "\n", - "Replaced word 'shows': 0.05s - get_replacement_token()\n", + "Replaced word 'Ġshows': 0.36s - get_replacement_token()\n", "Original: The clock shows 9:47 PM. How many minutes 'til 10?\n", - "Perturbed: The clock cheesy 9:47 PM. How many minutes 'til 10?\n", - "Chat Completion - Perturbed: 0.37s\n", - "Attributions computation: 0.00640106201171875s\n", - "- Cosine Attr: 0.0s\n", - "- Prob Diff Attr: 0.0s\n", - "- Token Displacement Attr: 0.0s\n", - "- Attr Logging: 0.0s\n", - "Total for word 'shows': 0.43s\n", + "Perturbed: The clockmedia 9:47 PM. How many minutes 'til 10?\n", + "Chat Completion - Perturbed: 1.08s\n", + "cosine attributed_tokens ['13', ' minutes', '.']\n", + "prob_diff attributed_tokens ['13', ' minutes', '.']\n", + "token_displacement attributed_tokens ['13', ' minutes', '.']\n", "\n", - "Replaced word '9:47': 0.12s - get_replacement_token()\n", + "Replaced word 'Ġ9:47': 0.71s - get_replacement_token()\n", "Original: The clock shows 9:47 PM. How many minutes 'til 10?\n", - "Perturbed: The clock shows teaspASON unpredict PM. How many minutes 'til 10?\n", - "Chat Completion - Perturbed: 1.59s\n", - "Attributions computation: 0.022313833236694336s\n", - "- Cosine Attr: 0.0s\n", - "- Prob Diff Attr: 0.0s\n", - "- Token Displacement Attr: 0.0s\n", - "- Attr Logging: 0.01s\n", - "Total for word '9:47': 1.74s\n", + "Perturbed: The clock shows accountingASON unpredict PM. How many minutes 'til 10?\n", + "Chat Completion - Perturbed: 1.38s\n", + "cosine attributed_tokens ['13', ' minutes', '.']\n", + "prob_diff attributed_tokens ['13', ' minutes', '.']\n", + "token_displacement attributed_tokens ['13', ' minutes', '.']\n", "\n", - "Replaced word 'PM.': 0.08s - get_replacement_token()\n", + "Replaced word 'ĠPM.': 0.52s - get_replacement_token()\n", "Original: The clock shows 9:47 PM. How many minutes 'til 10?\n", - "Perturbed: The clock shows 9:47 bombingroc How many minutes 'til 10?\n", - "Chat Completion - Perturbed: 0.94s\n", - "Attributions computation: 0.017086029052734375s\n", - "- Cosine Attr: 0.0s\n", - "- Prob Diff Attr: 0.0s\n", - "- Token Displacement Attr: 0.0s\n", - "- Attr Logging: 0.0s\n", - "Total for word 'PM.': 1.04s\n", + "Perturbed: The clock shows 9:47 deviationroc How many minutes 'til 10?\n", + "Chat Completion - Perturbed: 1.08s\n", + "cosine attributed_tokens ['13', ' minutes', '.']\n", + "prob_diff attributed_tokens ['13', ' minutes', '.']\n", + "token_displacement attributed_tokens ['13', ' minutes', '.']\n", "\n", - "Replaced word 'How': 0.04s - get_replacement_token()\n", + "Replaced word 'ĠHow': 0.25s - get_replacement_token()\n", "Original: The clock shows 9:47 PM. How many minutes 'til 10?\n", - "Perturbed: The clock shows 9:47 PM. empir many minutes 'til 10?\n", - "Chat Completion - Perturbed: 1.1s\n", - "Attributions computation: 0.008667945861816406s\n", - "- Cosine Attr: 0.0s\n", - "- Prob Diff Attr: 0.0s\n", - "- Token Displacement Attr: 0.0s\n", - "- Attr Logging: 0.0s\n", - "Total for word 'How': 1.15s\n", + "Perturbed: The clock shows 9:47 PM. cryptography many minutes 'til 10?\n", + "Chat Completion - Perturbed: 1.2s\n", + "cosine attributed_tokens ['13', ' minutes', '.']\n", + "prob_diff attributed_tokens ['13', ' minutes', '.']\n", + "token_displacement attributed_tokens ['13', ' minutes', '.']\n", "\n", - "Replaced word 'many': 0.04s - get_replacement_token()\n", + "Replaced word 'Ġmany': 0.26s - get_replacement_token()\n", "Original: The clock shows 9:47 PM. How many minutes 'til 10?\n", - "Perturbed: The clock shows 9:47 PM. How pubs minutes 'til 10?\n", - "Chat Completion - Perturbed: 0.9s\n", - "Attributions computation: 0.010850906372070312s\n", - "- Cosine Attr: 0.0s\n", - "- Prob Diff Attr: 0.0s\n", - "- Token Displacement Attr: 0.0s\n", - "- Attr Logging: 0.0s\n", - "Total for word 'many': 0.95s\n", + "Perturbed: The clock shows 9:47 PM. How acidic minutes 'til 10?\n", + "Chat Completion - Perturbed: 1.11s\n", + "cosine attributed_tokens ['13', ' minutes', '.']\n", + "prob_diff attributed_tokens ['13', ' minutes', '.']\n", + "token_displacement attributed_tokens ['13', ' minutes', '.']\n", "\n", - "Replaced word 'minutes': 0.09s - get_replacement_token()\n", + "Replaced word 'Ġminutes': 0.29s - get_replacement_token()\n", "Original: The clock shows 9:47 PM. How many minutes 'til 10?\n", - "Perturbed: The clock shows 9:47 PM. How manyRespakotil 10?\n", - "Chat Completion - Perturbed: 0.73s\n", - "Attributions computation: 0.019688844680786133s\n", - "- Cosine Attr: 0.0s\n", - "- Prob Diff Attr: 0.0s\n", - "- Token Displacement Attr: 0.0s\n", - "- Attr Logging: 0.0s\n", - "Total for word 'minutes': 0.84s\n", + "Perturbed: The clock shows 9:47 PM. How manystyles 'til 10?\n", + "Chat Completion - Perturbed: 0.87s\n", + "cosine attributed_tokens ['13', ' minutes', '.']\n", + "prob_diff attributed_tokens ['13', ' minutes', '.']\n", + "token_displacement attributed_tokens ['13', ' minutes', '.']\n", "\n", - "Replaced word ''til': 0.09s - get_replacement_token()\n", + "Replaced word 'Ġ'til': 0.55s - get_replacement_token()\n", "Original: The clock shows 9:47 PM. How many minutes 'til 10?\n", - "Perturbed: The clock shows 9:47 PM. How many minutes'refugees Naturally?\n", - "Chat Completion - Perturbed: 0.69s\n", - "Attributions computation: 0.022047996520996094s\n", - "- Cosine Attr: 0.0s\n", - "- Prob Diff Attr: 0.0s\n", - "- Token Displacement Attr: 0.0s\n", - "- Attr Logging: 0.0s\n", - "Total for word ''til': 0.8s\n", + "Perturbed: The clock shows 9:47 PM. How many minutes Wah clich 10?\n", + "Chat Completion - Perturbed: 0.86s\n", + "cosine attributed_tokens ['13', ' minutes', '.']\n", + "prob_diff attributed_tokens ['13', ' minutes', '.']\n", + "token_displacement attributed_tokens ['13', ' minutes', '.']\n", "\n", - "Replaced word '10?': 0.09s - get_replacement_token()\n", + "Replaced word 'Ġ10?': 0.48s - get_replacement_token()\n", "Original: The clock shows 9:47 PM. How many minutes 'til 10?\n", - "Perturbed: The clock shows 9:47 PM. How many minutes 'til 10 setbacks susceptible\n", - "Chat Completion - Perturbed: 0.82s\n", - "Attributions computation: 0.019437074661254883s\n", - "- Cosine Attr: 0.0s\n", - "- Prob Diff Attr: 0.0s\n", - "- Token Displacement Attr: 0.0s\n", - "- Attr Logging: 0.0s\n", - "Total for word '10?': 0.93s\n", + "Perturbed: The clock shows 9:47 PM. How many minutes 'til lawsuit susceptible\n", + "Chat Completion - Perturbed: 0.61s\n", + "cosine attributed_tokens ['13', ' minutes', '.']\n", + "prob_diff attributed_tokens ['13', ' minutes', '.']\n", + "token_displacement attributed_tokens ['13', ' minutes', '.']\n", "\n", "\n", - "Exp Total: 9.65583610534668s\n", + "Exp Total: 18.051724910736084s\n", "\n", "\n", "The clock shows 9:47 PM. How many minutes 'til 10? ('13 minutes.',)\n" @@ -349,8 +406,7 @@ } ], "source": [ - "input_sequences = [\n", - " \"The clock shows 9:47 PM. How many minutes 'til 10?\"]\n", + "input_texts = [\"The clock shows 9:47 PM. How many minutes 'til 10?\"]\n", "# \"The building is 132 meters tall. How tall is the building?\",\n", "# \"The package weighs 8.6 kilograms. How much does the package weigh?\",\n", "# \"The thermometer reads 23 degrees Celsius. What is the temperature according to the thermometer?\",\n", @@ -363,26 +419,27 @@ "# ]\n", "\n", "\n", - "for input_sequence in input_sequences:\n", + "for input_text in input_texts:\n", " for perturbation_strategy in [\"distant\"]:\n", - " original_output = calculate_token_importance_in_sequence(\n", - " input_sequence,\n", + " original_output = calculate_token_importance(\n", + " input_text,\n", " model,\n", " tokenizer,\n", " perturbation_strategy,\n", " attribution_strategies=[\"cosine\", \"prob_diff\", \"token_displacement\"],\n", " logger=logger,\n", + " perturb_word_wise=True,\n", " )\n", "\n", " print(\n", - " input_sequence,\n", + " input_text,\n", " original_output,\n", " )" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -410,6 +467,7 @@ " original_input\n", " original_output\n", " perturbation_strategy\n", + " perturb_word_wise\n", " duration\n", " \n", " \n", @@ -420,7 +478,8 @@ " The clock shows 9:47 PM. How many minutes 'til...\n", " 13 minutes.\n", " distant\n", - " 9.658044\n", + " True\n", + " 18.055734\n", " \n", " \n", "\n", @@ -430,8 +489,8 @@ " exp_id original_input original_output \\\n", "0 1 The clock shows 9:47 PM. How many minutes 'til... 13 minutes. \n", "\n", - " perturbation_strategy duration \n", - "0 distant 9.658044 " + " perturbation_strategy perturb_word_wise duration \n", + "0 distant True 18.055734 " ] }, "metadata": {}, @@ -444,162 +503,159 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "
 exp_idattribution_strategyperturbation_strategytoken_1token_2token_3token_4token_5token_6token_7token_8token_9token_10token_11token_12token_13token_14token_15token_16exp_idattribution_strategyperturbation_strategyperturb_word_wisetoken_1token_2token_3token_4token_5token_6token_7token_8token_9token_10token_11token_12token_13token_14token_15
01cosinedistantThe\n", - "0.14clock\n", - "0.14shows\n", - "0.109\n", - "0.20:\n", - "0.2047\n", - "0.20PM\n", - "0.17.\n", - "0.17How\n", - "0.14many\n", - "0.14min\n", - "0.14utes\n", - "0.14't\n", - "0.10il\n", - "0.1010\n", + " 01cosinedistantTrueThe\n", "0.00?\n", + " clock\n", "0.00
11prob_diffdistantThe\n", - "0.76clock\n", - "0.70shows\n", - "0.289\n", - "0.80:\n", - "0.8047\n", - "0.80PM\n", - "0.78.\n", - "0.78How\n", - "0.79many\n", - "0.72min\n", - "0.80utes\n", + " shows\n", + "0.009\n", + "0.15:\n", + "0.1547\n", + "0.15PM\n", + "0.15.\n", + "0.15How\n", + "0.13many\n", + "0.13minutes\n", + "0.13'\n", + "0.00til\n", + "0.0010\n", + "0.11?\n", + "0.11
11prob_diffdistantTrueThe\n", + "0.08clock\n", + "0.12shows\n", + "0.119\n", + "0.82:\n", + "0.8247\n", + "0.82PM\n", + "0.81.\n", + "0.81How\n", + "0.81many\n", "0.80't\n", - "0.29il\n", - "0.2910\n", - "0.10?\n", - "0.10
21token_displacementdistantThe\n", + " minutes\n", + "0.82'\n", + "0.05til\n", + "0.0510\n", + "0.24?\n", + "0.24
21token_displacementdistantTrueThe\n", + "0.00clock\n", + "0.00shows\n", + "0.009\n", + "19.33:\n", + "19.3347\n", + "19.33PM\n", + "12.67.\n", + "12.67How\n", "13.67clock\n", - "13.00shows\n", - "6.679\n", - "20.00:\n", - "20.0047\n", - "20.00PM\n", - "13.33.\n", - "13.33How\n", + " many\n", "13.67many\n", + " minutes\n", "13.67min\n", - "16.67utes\n", - "16.67't\n", - "6.67il\n", - "6.6710\n", + " '\n", "0.00?\n", + " til\n", "0.0010\n", + "6.67?\n", + "6.67
\n" ], "text/plain": [ - "" + "" ] }, "metadata": {}, @@ -612,7 +668,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -627,134 +683,128 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 The (0)clock (1)shows (2)13 (0) minutes (1). (2)
The (0)0.6003460.7595860.701984The (0)0.0000000.0000000.000000
clock (1)0.6003460.7595860.701984clock (1)0.0000000.0000000.000000
shows (2)-0.000000-0.0000001.000000shows (2)0.0000000.0000000.000000
9 (3)0.6003460.7595860.7474809 (3)0.6889560.7595860.747480
: (4)0.6003460.7595860.747480: (4)0.6889560.7595860.747480
47 (5)0.6003460.7595860.74748047 (5)0.6889560.7595860.747480
PM (6)0.6003460.7595860.701984PM (6)0.6889560.7595860.701984
. (7)0.6003460.7595860.701984. (7)0.6889560.7595860.701984
How (8)0.6003460.7595860.701984How (8)0.6889560.7595860.701984
many (9)0.6003460.7595860.701984many (9)0.6889560.7595860.701984
min (10)0.6003460.7595860.701984minutes (10)0.6889560.7595860.701984
utes (11)0.6003460.7595860.701984' (11)0.0000000.0000000.000000
't (12)-0.000000-0.0000001.000000til (12)0.0000000.0000000.000000
il (13)-0.000000-0.0000001.00000010 (13)0.0000000.0000001.000000
10 (14)-0.000000-0.0000000.000000
? (15)-0.000000-0.0000000.000000? (14)0.0000000.0000001.000000
\n" ], "text/plain": [ - "" + "" ] }, "metadata": {}, @@ -772,154 +822,160 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 13 (0) minutes (1). (2)13 (0) minutes (1). (2)
The (0)0.6467140.9999190.619983
clock (1)0.4897120.9999190.619983The (0)0.2398910.0000120.004187
shows (2)0.2316510.0000330.619983clock (1)0.3391380.0000610.025875
9 (3)0.7858890.9999190.619983shows (2)0.2999490.0000290.024715
: (4)0.7858890.9999190.6199839 (3)0.8179920.9998720.640487
47 (5)0.7858890.9999190.619983: (4)0.8179920.9998720.640487
PM (6)0.7073770.9999190.61998347 (5)0.8179920.9998720.640487
. (7)0.7073770.9999190.619983PM (6)0.8043870.9998720.640487
How (8)0.7543070.9999190.619983. (7)0.8043870.9998720.640487
many (9)0.5350380.9999190.619983How (8)0.7876020.9998720.640487
min (10)0.7855920.9999190.619983many (9)0.7582780.9998720.640487
utes (11)0.7855920.9999190.619983minutes (10)0.8140920.9998720.640487
't (12)0.2572380.0007040.619983' (11)0.0423290.0012050.102655
il (13)0.2572380.0007040.619983til (12)0.0423290.0012050.102655
10 (14)0.1135660.0001290.19463410 (13)0.0824090.0001640.640487
? (15)0.1135660.0001290.194634? (14)0.0824090.0001640.640487
\n" ], "text/plain": [ - "" + "" ] }, "metadata": {}, @@ -937,142 +993,132 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 13 (0) minutes (1). (2)13 (0) minutes (1). (2)
The (0)1.00000020.00000020.000000The (0)0.0000000.0000000.000000
clock (1)1.00000018.00000020.000000clock (1)0.0000000.0000000.000000
shows (2)0.0000000.00000020.000000shows (2)0.0000000.0000000.000000
9 (3)20.00000020.00000020.0000009 (3)20.00000018.00000020.000000
: (4)20.00000020.00000020.000000: (4)20.00000018.00000020.000000
47 (5)20.00000020.00000020.00000047 (5)20.00000018.00000020.000000
PM (6)1.00000019.00000020.000000PM (6)1.00000017.00000020.000000
. (7)1.00000019.00000020.000000. (7)1.00000017.00000020.000000
How (8)1.00000020.00000020.000000How (8)1.00000020.00000020.000000
many (9)1.00000020.00000020.000000many (9)1.00000020.00000020.000000
min (10)10.00000020.00000020.000000minutes (10)1.00000020.00000020.000000
utes (11)10.00000020.00000020.000000' (11)0.0000000.0000000.000000
't (12)0.0000000.00000020.000000til (12)0.0000000.0000000.000000
il (13)0.0000000.00000020.00000010 (13)0.0000000.00000020.000000
10 (14)0.0000000.0000000.000000
? (15)0.0000000.0000000.000000? (14)0.0000000.00000020.000000
\n" ], "text/plain": [ - "" + "" ] }, "metadata": {}, @@ -1085,14 +1131,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Attribution matrix for cosine with perturbation strategy distant:\n", + "Attribution matrix for prob_diff with perturbation strategy distant:\n", "Input Tokens (Rows) vs. Output Tokens (Columns)\n" ] }, @@ -1100,240 +1146,160 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 The (0)clock (1)shows (2)9 (3): (4)47 (5)PM (6). (7)13 (0) minutes (1). (2)
The (0)0.000000-0.000000-0.0000000.0000000.000000-0.0000000.0000000.000000
clock (1)0.000000-0.000000-0.0000000.0000000.000000-0.0000000.0000000.000000
shows (2)0.000000-0.000000-0.0000000.0000000.000000-0.0000000.0000000.000000
9 (3)0.000000-0.000000-0.0000000.2891770.0000000.2749360.0000000.000000
: (4)0.000000-0.000000-0.0000000.2891770.0000000.2749360.0000000.000000
47 (5)0.000000-0.000000-0.0000000.2891770.0000000.2749360.0000000.000000
PM (6)0.000000-0.0000000.0000000.0000000.000000-0.0000000.8215001.000000
. (7)0.000000-0.0000000.0000000.0000000.000000-0.0000000.8215001.000000
What (8)0.000000-0.000000-0.0000000.0000000.000000-0.0000000.0000000.000000
time (9)0.000000-0.0000000.0000000.4880540.000000-0.0000000.8215001.000000
does (10)0.5605800.8569700.6968210.7463261.0000001.0000001.0000001.000000
the (11)0.000000-0.000000-0.0000000.0000000.000000-0.0000000.0000000.000000
clock (12)0.000000-0.000000-0.0000000.0000000.000000-0.0000000.0000000.000000
show (13)0.000000-0.0000000.8286130.7540330.5130170.6937780.7817760.758650
? (14)0.000000-0.0000000.8286130.7540330.5130170.6937780.7817760.758650The (0)0.2398910.0000120.004187
clock (1)0.3391380.0000610.025875
shows (2)0.2999490.0000290.024715
9 (3)0.8179920.9998720.640487
: (4)0.8179920.9998720.640487
47 (5)0.8179920.9998720.640487
PM (6)0.8043870.9998720.640487
. (7)0.8043870.9998720.640487
How (8)0.7876020.9998720.640487
many (9)0.7582780.9998720.640487
minutes (10)0.8140920.9998720.640487
' (11)0.0423290.0012050.102655
til (12)0.0423290.0012050.102655
10 (13)0.0824090.0001640.640487
? (14)0.0824090.0001640.640487
\n" ], "text/plain": [ - "" + "" ] }, "metadata": {}, @@ -1341,12 +1307,12 @@ } ], "source": [ - "logger.print_attribution_matrix(1, \"cosine\")" + "logger.print_attribution_matrix(exp_id=1, attribution_strategy=\"prob_diff\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -1361,583 +1327,128 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - "
 The (0)clock (1)shows (2)9 (3): (4)47 (5)PM (6). (7)13 (0) minutes (1). (2)
The (0)0.000000-0.000000-0.0000000.0000000.000000-0.0000000.0000000.000000
clock (1)0.000000-0.000000-0.0000000.0000000.000000-0.0000000.0000000.000000
shows (2)0.000000-0.000000-0.0000000.0000000.000000-0.0000000.0000000.000000
9 (3)0.000000-0.000000-0.0000000.2891770.0000000.2749360.0000000.000000
: (4)0.000000-0.000000-0.0000000.2891770.0000000.2749360.0000000.000000
47 (5)0.000000-0.000000-0.0000000.2891770.0000000.2749360.0000000.000000
PM (6)0.000000-0.0000000.0000000.0000000.000000-0.0000000.8215001.000000
. (7)0.000000-0.0000000.0000000.0000000.000000-0.0000000.8215001.000000
What (8)0.000000-0.000000-0.0000000.0000000.000000-0.0000000.0000000.000000
time (9)0.000000-0.0000000.0000000.4880540.000000-0.0000000.8215001.000000
does (10)0.5605800.8569700.6968210.7463261.0000001.0000001.0000001.000000
the (11)0.000000-0.000000-0.0000000.0000000.000000-0.0000000.0000000.000000
clock (12)0.000000-0.000000-0.0000000.0000000.000000-0.0000000.0000000.000000
show (13)0.000000-0.0000000.8286130.7540330.5130170.6937780.7817760.758650
? (14)0.000000-0.0000000.8286130.7540330.5130170.6937780.7817760.758650The (0)0.0000000.0000000.000000
\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Attribution matrix for prob_diff with perturbation strategy distant:\n", - "Input Tokens (Rows) vs. Output Tokens (Columns)\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", - " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 The (0) clock (1) shows (2) (3)9 (4): (5)47 (6) PM (7). (8)clock (1)0.0000000.0000000.000000
shows (2)0.0000000.0000000.000000
9 (3)0.6889560.7595860.747480
: (4)0.6889560.7595860.747480
47 (5)0.6889560.7595860.747480
PM (6)0.6889560.7595860.701984
The (0)0.0644480.0078690.0030850.0029890.0001080.0001280.0001180.0000540.007132
clock (1)0.1364580.0339540.0038100.0027360.0002850.0001010.0014140.0004320.000383
shows (2)0.1516680.0068040.0092900.0061510.0000790.0001410.0007330.0001250.016142
9 (3)0.1917580.0081940.0166630.0402860.9510410.0003150.9582580.0001380.001367
: (4)0.1917580.0081940.0166630.0402860.9510410.0003150.9582580.0001380.001367
47 (5)0.1917580.0081940.0166630.0402860.9510410.0003150.9582580.0001380.001367
PM (6)0.1710690.0001170.0010730.0001210.0001130.0001390.0000160.9998580.994526
. (7)0.1710690.0001170.0010730.0001210.0001130.0001390.0000160.9998580.994526
What (8)0.1981330.0078140.1484310.0184170.0030660.0000650.0031200.0016210.126054
time (9)0.1968400.0053410.0098950.0773050.8847650.0000520.0001320.9254030.994526
does (10)0.3069420.9893260.9975370.9993540.9996700.9998560.9998990.9998960.994526
the (11)0.0792020.4819940.1167430.1952750.0086850.0001060.0019670.0051890.018881
clock (12)0.1322340.0822220.0655670.0259150.0041240.0001140.0001780.0005700.028210
show (13)0.0068050.0140620.9807600.9993610.9996700.9998560.9977560.9998650.994526
? (14)0.0068050.0140620.9807600.9993610.9996700.9998560.9977560.9998650.994526. (7)0.6889560.7595860.701984
How (8)0.6889560.7595860.701984
many (9)0.6889560.7595860.701984
minutes (10)0.6889560.7595860.701984
' (11)0.0000000.0000000.000000
til (12)0.0000000.0000000.000000
10 (13)0.0000000.0000001.000000
? (14)0.0000000.0000001.000000
\n" ], "text/plain": [ - "" + "" ] }, "metadata": {}, @@ -1945,530 +1456,39 @@ } ], "source": [ - "logger.print_attribution_matrix(2, attribution_strategy=\"cosine\")\n", - "logger.print_attribution_matrix(2, attribution_strategy=\"prob_diff\")" + "logger.print_attribution_matrix(1, \"cosine\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Attribution matrix for cosine with perturbation strategy distant:\n", - "Input Tokens (Rows) vs. Output Tokens (Columns)\n" + "ename": "IndexError", + "evalue": "index 0 is out of bounds for axis 0 with size 0", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mlogger\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprint_attribution_matrix\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattribution_strategy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcosine\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m logger\u001b[38;5;241m.\u001b[39mprint_attribution_matrix(\u001b[38;5;241m2\u001b[39m, attribution_strategy\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mprob_diff\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/Projects/llm-attribution/attribution/experiment_logger.py:184\u001b[0m, in \u001b[0;36mExperimentLogger.print_attribution_matrix\u001b[0;34m(self, exp_id, attribution_strategy)\u001b[0m\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 176\u001b[0m \u001b[38;5;66;03m# Filter the data for the specific experiment and attribution strategy\u001b[39;00m\n\u001b[1;32m 177\u001b[0m exp_data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdf_token_attribution_matrix[\n\u001b[1;32m 178\u001b[0m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdf_token_attribution_matrix[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexp_id\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m==\u001b[39m exp_id)\n\u001b[1;32m 179\u001b[0m \u001b[38;5;241m&\u001b[39m (\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 182\u001b[0m )\n\u001b[1;32m 183\u001b[0m ]\n\u001b[0;32m--> 184\u001b[0m perturbation_strategy \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdf_experiments\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloc\u001b[49m\u001b[43m[\u001b[49m\n\u001b[1;32m 185\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdf_experiments\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mexp_id\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mexp_id\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mperturbation_strategy\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\n\u001b[1;32m 186\u001b[0m \u001b[43m \u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# Create the pivot table for the matrix\u001b[39;00m\n\u001b[1;32m 189\u001b[0m matrix \u001b[38;5;241m=\u001b[39m exp_data\u001b[38;5;241m.\u001b[39mpivot(\n\u001b[1;32m 190\u001b[0m index\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_token_pos\u001b[39m\u001b[38;5;124m\"\u001b[39m, columns\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moutput_token_pos\u001b[39m\u001b[38;5;124m\"\u001b[39m, values\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattr_score\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 191\u001b[0m )\n", + "\u001b[0;31mIndexError\u001b[0m: index 0 is out of bounds for axis 0 with size 0" ] - }, - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
 The (0)building (1)is (2)132 (3)meters (4)tall (5). (6)
The (0)0.0000000.000000-0.0000000.0000000.0000000.0000000.000000
building (1)0.2266590.7050130.5200120.8062860.8554250.7560680.466330
is (2)0.0000000.000000-0.0000000.0000000.0000000.0000000.000000
132 (3)0.2266590.7703410.6210990.8078230.7857450.7674780.480238
met (4)0.2266590.7050130.5200120.7880100.7989890.7719130.412676
ers (5)0.2266590.7050130.5200120.7880100.7989890.7719130.412676
tall (6)0.0000000.000000-0.0000000.0000000.0000000.0000000.000000
. (7)0.0000000.000000-0.0000000.0000000.0000000.0000000.000000
How (8)0.0000000.000000-0.0000000.0000000.0000000.0000000.000000
tall (9)0.0000000.000000-0.0000000.0000000.0000000.0000000.000000
is (10)0.0000000.000000-0.0000000.0000000.0000000.0000000.000000
the (11)0.0000000.7585920.4103120.6982630.8183600.8033580.749395
building (12)0.2266590.7050130.5200120.7880100.7989890.7785790.796747
? (13)0.2266590.7050130.5200120.7880100.7989890.7785790.796747
\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Attribution matrix for prob_diff with perturbation strategy distant:\n", - "Input Tokens (Rows) vs. Output Tokens (Columns)\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
 The (0) building (1) is (2) (3)132 (4) meters (5) tall (6). (7)
The (0)0.0187360.0001130.0003090.0000790.0000530.0000060.0001210.000022
building (1)0.7777440.9990230.9998830.9999550.9999620.9999890.9998220.999864
is (2)0.0153640.0007130.0083160.0001660.0000160.0000070.0001460.000301
132 (3)0.9856850.9990230.9998830.9998930.9999620.9999890.9998220.999864
met (4)0.9634650.9990230.9998830.9999420.9999620.9999890.9998220.999864
ers (5)0.9634650.9990230.9998830.9999420.9999620.9999890.9998220.999864
tall (6)0.0574430.0000390.0000010.0012330.0002110.0000130.0010110.000231
. (7)0.0574430.0000390.0000010.0012330.0002110.0000130.0010110.000231
How (8)0.0173490.0130340.0062530.3514490.0245280.0001970.0341640.228418
tall (9)0.0283250.0241040.0021900.0045850.0000510.0000160.0001000.001761
is (10)0.3039730.4399570.0093630.0010680.0001150.0000570.0000240.001697
the (11)0.4550510.9877370.7951230.9999350.9999620.9999890.9998220.989866
building (12)0.9727750.9990230.9998830.9999550.9999620.9999890.9998220.999864
? (13)0.9727750.9990230.9998830.9999550.9999620.9999890.9998220.999864
\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" } ], "source": [ - "logger.print_attribution_matrix(3, attribution_strategy=\"cosine\")\n", - "logger.print_attribution_matrix(3, attribution_strategy=\"prob_diff\")" + "logger.print_attribution_matrix(2, attribution_strategy=\"cosine\")\n", + "logger.print_attribution_matrix(2, attribution_strategy=\"prob_diff\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -2493,7 +1513,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.9.19" } }, "nbformat": 4,