diff --git a/attribution/attribution_metrics.py b/attribution/attribution_metrics.py
index 2344094..b42eebf 100644
--- a/attribution/attribution_metrics.py
+++ b/attribution/attribution_metrics.py
@@ -109,34 +109,40 @@ def cosine_similarity_attribution(
tokenizer: PreTrainedTokenizer,
) -> Tuple[float, np.ndarray]:
# Extract embeddings
- initial_sentence_emb, initial_token_embs = get_sentence_embeddings(
+ initial_output_sentence_emb, initial_output_token_embs = get_sentence_embeddings(
original_output_choice.message.content, model, tokenizer
)
- perturbed_sentence_emb, perturbed_token_embs = get_sentence_embeddings(
- perturbed_output_choice.message.content, model, tokenizer
+ perturbed_output_sentence_emb, perturbed_output_token_embs = (
+ get_sentence_embeddings(
+ perturbed_output_choice.message.content, model, tokenizer
+ )
)
# Reshape embeddings
- initial_sentence_emb = initial_sentence_emb.reshape(1, -1)
- perturbed_sentence_emb = perturbed_sentence_emb.reshape(1, -1)
+ initial_output_sentence_emb = initial_output_sentence_emb.reshape(1, -1)
+ perturbed_output_sentence_emb = perturbed_output_sentence_emb.reshape(1, -1)
# Calculate similarities
self_similarity = float(
- cosine_similarity(initial_sentence_emb, initial_sentence_emb)
+ cosine_similarity(initial_output_sentence_emb, initial_output_sentence_emb)
)
sentence_similarity = float(
- cosine_similarity(initial_sentence_emb, perturbed_sentence_emb)
+ cosine_similarity(initial_output_sentence_emb, perturbed_output_sentence_emb)
)
# Calculate token similarities for shared length
- shared_length = min(initial_token_embs.shape[0], perturbed_token_embs.shape[0])
+ shared_length = min(
+ initial_output_token_embs.shape[0], perturbed_output_token_embs.shape[0]
+ )
token_similarities_shared = cosine_similarity(
- initial_token_embs[:shared_length], perturbed_token_embs[:shared_length]
+ initial_output_token_embs[:shared_length],
+ perturbed_output_token_embs[:shared_length],
).diagonal()
# Pad token similarities to match initial token embeddings shape
token_similarities = np.pad(
- token_similarities_shared, (0, initial_token_embs.shape[0] - shared_length)
+ token_similarities_shared,
+ (0, initial_output_token_embs.shape[0] - shared_length),
)
# Return difference in sentence similarity and token similarities
diff --git a/attribution/experiment_logger.py b/attribution/experiment_logger.py
index d04b2ac..1265338 100644
--- a/attribution/experiment_logger.py
+++ b/attribution/experiment_logger.py
@@ -15,6 +15,7 @@ def __init__(self, experiment_id=0):
"original_input",
"original_output",
"perturbation_strategy",
+ "perturb_word_wise",
"duration",
]
)
@@ -51,7 +52,11 @@ def __init__(self, experiment_id=0):
)
def start_experiment(
- self, original_input: str, original_output: str, perturbation_strategy: str
+ self,
+ original_input: str,
+ original_output: str,
+ perturbation_strategy: str,
+ perturb_word_wise: bool,
):
self.experiment_id += 1
self.experiment_start_time = time.time()
@@ -60,6 +65,7 @@ def start_experiment(
"original_input": original_input,
"original_output": original_output,
"perturbation_strategy": perturbation_strategy,
+ "perturb_word_wise": perturb_word_wise,
"duration": None,
}
@@ -140,11 +146,15 @@ def print_sentence_attribution(self):
perturbation_strategy = self.df_experiments.loc[
self.df_experiments["exp_id"] == exp_id, "perturbation_strategy"
].values[0]
+ perturb_word_wise = self.df_experiments.loc[
+ self.df_experiments["exp_id"] == exp_id, "perturb_word_wise"
+ ].values[0]
sentence_data = {
"exp_id": exp_id,
"attribution_strategy": attr_strat,
"perturbation_strategy": perturbation_strategy,
+ "perturb_word_wise": perturb_word_wise,
}
sentence_data.update(
{f"token_{i+1}": token_attr for i, token_attr in enumerate(token_attrs)}
diff --git a/research/api_llm_attribution.ipynb b/research/api_llm_attribution.ipynb
index ca10f98..91ff9ba 100644
--- a/research/api_llm_attribution.ipynb
+++ b/research/api_llm_attribution.ipynb
@@ -2,9 +2,18 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 1,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/sebastian/Projects/llm-attribution/.venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
"source": [
"import os\n",
"import timeit\n",
@@ -41,7 +50,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -51,12 +60,12 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"model = GPT2LMHeadModel.from_pretrained(\"gpt2\") # or any other checkpoint\n",
- "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
+ "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\", add_prefix_space=True)\n",
"\n",
"word_token_embeddings = model.transformer.wte.weight.detach().numpy()\n",
"position_embeddings = model.transformer.wpe.weight.detach().numpy()\n",
@@ -67,7 +76,59 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 70,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Tokenizing the entire input text:\n",
+ "[['Ġthe'], ['Ġcamel'], ['Ġcase'], ['Ġomnip'], ['otent']]\n",
+ "[[262], [41021], [1339], [40046], [33715]]\n",
+ "\n",
+ "Tokenizing word by word:\n",
+ "[['Ġthe'], ['Ġcamel'], ['Ġcase'], ['Ġomnip', 'otent']]\n",
+ "[[262], [41021], [1339], [40046, 33715]]\n",
+ "\n",
+ "[262, 41021, 1339]\n",
+ "[[262], [41021], [1339]]\n",
+ "[262, 41021, 1339]\n"
+ ]
+ }
+ ],
+ "source": [
+ "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\", add_prefix_space=True)\n",
+ "\n",
+ "# Test input text\n",
+ "input_text = \"the camel case omnipotent\"\n",
+ "\n",
+ "# Tokenize the entire input text\n",
+ "input_tokens = [[token] for token in tokenizer.tokenize(input_text)]\n",
+ "input_token_ids = [\n",
+ " [token_id] for token_id in tokenizer.encode(input_text, add_special_tokens=False)\n",
+ "]\n",
+ "print(\"Tokenizing the entire input text:\")\n",
+ "print(input_tokens)\n",
+ "print(input_token_ids)\n",
+ "\n",
+ "# Tokenize word by word\n",
+ "words = input_text.split()\n",
+ "word_tokens = [tokenizer.tokenize(word) for word in words]\n",
+ "word_token_ids = [tokenizer.encode(word, add_special_tokens=False) for word in words]\n",
+ "print(\"\\nTokenizing word by word:\")\n",
+ "print(word_tokens)\n",
+ "print(word_token_ids)\n",
+ "\n",
+ "print()\n",
+ "print([token_id for sublist in input_token_ids[:3] for token_id in sublist])\n",
+ "print([sublist for sublist in input_token_ids[:3]])\n",
+ "print([token_id for sublist in input_token_ids[:3] for token_id in sublist])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -83,8 +144,8 @@
" return response.choices[0]\n",
"\n",
"\n",
- "def calculate_token_importance_in_sequence(\n",
- " input_sequence: str,\n",
+ "def calculate_token_importance(\n",
+ " input_text: str,\n",
" model: PreTrainedModel,\n",
" tokenizer: PreTrainedTokenizer,\n",
" perturbation_strategy: str = \"fixed\",\n",
@@ -94,55 +155,81 @@
" \"token_displacement\",\n",
" ],\n",
" logger: ExperimentLogger = None,\n",
+ " perturb_word_wise: bool = False,\n",
"):\n",
" timestamp = time()\n",
- " tokens = tokenizer.tokenize(input_sequence)\n",
- " token_ids = tokenizer.encode(input_sequence, add_special_tokens=False)\n",
- " original_output = get_model_output(input_sequence)\n",
+ " original_output = get_model_output(input_text)\n",
" print(f\"Chat Completion - Original: {round(time() - timestamp, 2)}s\")\n",
"\n",
" if logger:\n",
" logger.start_experiment(\n",
- " input_sequence, original_output.message.content, perturbation_strategy\n",
+ " input_text,\n",
+ " original_output.message.content,\n",
+ " perturbation_strategy,\n",
+ " perturb_word_wise,\n",
" )\n",
"\n",
" exp_timestamp = time()\n",
- " words = input_sequence.split()\n",
- " token_index = 0\n",
- " for word in words:\n",
- " word_tokens = tokenizer.tokenize(word)\n",
- " word_token_ids = tokenizer.encode(word, add_special_tokens=False)\n",
+ "\n",
+ " # A unit is either a word or a single token\n",
+ " unit_offset = 0\n",
+ " if perturb_word_wise:\n",
+ " words = input_text.split()\n",
+ " tokens_per_unit = [tokenizer.tokenize(word) for word in words]\n",
+ " token_ids_per_unit = [\n",
+ " tokenizer.encode(word, add_special_tokens=False) for word in words\n",
+ " ]\n",
+ " else:\n",
+ " tokens_per_unit = [[token] for token in tokenizer.tokenize(input_text)]\n",
+ " token_ids_per_unit = [\n",
+ " [token_id]\n",
+ " for token_id in tokenizer.encode(input_text, add_special_tokens=False)\n",
+ " ]\n",
+ "\n",
+ " for i_unit, unit_tokens in enumerate(tokens_per_unit):\n",
" start_word_time = time()\n",
" replacement_token_ids = [\n",
" get_replacement_token(\n",
- " word_token_ids[i],\n",
+ " token_id,\n",
" perturbation_strategy,\n",
" word_token_embeddings,\n",
" tokenizer,\n",
" )\n",
- " for i in range(len(word_tokens))\n",
+ " for token_id in token_ids_per_unit[i_unit]\n",
" ]\n",
" print(\n",
- " f\"\\nReplaced word '{word}': {round(time() - start_word_time, 2)}s - get_replacement_token()\"\n",
+ " f\"\\nReplaced word '{''.join(unit_tokens)}': {round(time() - start_word_time, 2)}s - get_replacement_token()\"\n",
" )\n",
"\n",
" # Replace the current word with the new tokens\n",
+ " left_token_ids = [\n",
+ " token_id\n",
+ " for unit_token_ids in token_ids_per_unit[:i_unit]\n",
+ " for token_id in unit_token_ids\n",
+ " ]\n",
+ " right_token_ids = [\n",
+ " token_id\n",
+ " for unit_token_ids in token_ids_per_unit[i_unit + 1 :]\n",
+ " for token_id in unit_token_ids\n",
+ " ]\n",
" perturbed_input = tokenizer.decode(\n",
- " token_ids[:token_index]\n",
- " + replacement_token_ids\n",
- " + token_ids[token_index + len(word_tokens) :]\n",
+ " left_token_ids + replacement_token_ids + right_token_ids\n",
" )\n",
"\n",
" # Get the output logprobs for the perturbed input\n",
" timestamp = time()\n",
- " print('Original: ',input_sequence)\n",
- " print('Perturbed: ',perturbed_input)\n",
+ " print(\"Original: \", input_text)\n",
+ " print(\"Perturbed: \", perturbed_input)\n",
" perturbed_output = get_model_output(perturbed_input)\n",
" print(f\"Chat Completion - Perturbed: {round(time() - timestamp, 2)}s\")\n",
"\n",
" timestamp = time()\n",
" for attribution_strategy in attribution_strategies:\n",
- " attributed_tokens = tokens\n",
+ " attributed_tokens = [\n",
+ " token_logprob.token\n",
+ " for token_logprob in original_output.logprobs.content\n",
+ " ]\n",
+ " print(attribution_strategy, \"attributed_tokens\", attributed_tokens)\n",
" if attribution_strategy == \"cosine\":\n",
" cosine_timestamp = time()\n",
" sentence_attr, token_attributions = cosine_similarity_attribution(\n",
@@ -172,35 +259,35 @@
"\n",
" if logger:\n",
" start_logging = time()\n",
- " for i in range(len(word_tokens)):\n",
+ " for i, unit_token in enumerate(unit_tokens):\n",
" logger.log_input_token_attribution(\n",
" attribution_strategy,\n",
- " token_index + i,\n",
- " word_tokens[i],\n",
+ " unit_offset + i,\n",
+ " unit_token,\n",
" float(sentence_attr),\n",
" )\n",
" for j, attr_score in enumerate(token_attributions):\n",
" logger.log_token_attribution_matrix(\n",
" attribution_strategy,\n",
- " token_index + i,\n",
+ " unit_offset + i,\n",
" j,\n",
" attributed_tokens[j],\n",
" attr_score.squeeze(),\n",
" )\n",
" end_logging = time()\n",
" time_all_attrs = time() - timestamp\n",
- " print(f\"Attributions computation: {time_all_attrs}s\")\n",
- " print(f\"- Cosine Attr: {round(cosine_timestamp_end - cosine_timestamp, 2)}s\")\n",
- " print(\n",
- " f\"- Prob Diff Attr: {round(prob_diff_timestamp_end - prob_diff_timestamp, 2)}s\"\n",
- " )\n",
- " print(\n",
- " f\"- Token Displacement Attr: {round(token_displacement_timestamp_end - token_displacement_timestamp, 2)}s\"\n",
- " )\n",
- " print(f\"- Attr Logging: {round(end_logging - start_logging, 2)}s\")\n",
- " print(f\"Total for word '{word}': {round(time() - start_word_time, 2)}s\")\n",
+ " # print(f\"Attributions computation: {time_all_attrs}s\")\n",
+ " # print(f\"- Cosine Attr: {round(cosine_timestamp_end - cosine_timestamp, 2)}s\")\n",
+ " # print(\n",
+ " # f\"- Prob Diff Attr: {round(prob_diff_timestamp_end - prob_diff_timestamp, 2)}s\"\n",
+ " # )\n",
+ " # print(\n",
+ " # f\"- Token Displacement Attr: {round(token_displacement_timestamp_end - token_displacement_timestamp, 2)}s\"\n",
+ " # )\n",
+ " # print(f\"- Attr Logging: {round(end_logging - start_logging, 2)}s\")\n",
+ " # print(f\"Total for word '{word}': {round(time() - start_word_time, 2)}s\")\n",
"\n",
- " token_index += len(word_tokens)\n",
+ " unit_offset += len(unit_tokens)\n",
"\n",
" print(f\"\\n\\nExp Total: {time() - exp_timestamp}s\\n\\n\")\n",
"\n",
@@ -209,7 +296,7 @@
" i,\n",
" tokenizer.decode(replacement_token_ids)[0],\n",
" perturbation_strategy,\n",
- " input_sequence,\n",
+ " input_text,\n",
" original_output.message.content,\n",
" perturbed_input,\n",
" perturbed_output.message.content,\n",
@@ -221,127 +308,97 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Chat Completion - Original: 0.73s\n",
+ "Chat Completion - Original: 1.26s\n",
"\n",
- "Replaced word 'The': 0.06s - get_replacement_token()\n",
+ "Replaced word 'ĠThe': 0.26s - get_replacement_token()\n",
"Original: The clock shows 9:47 PM. How many minutes 'til 10?\n",
- "Perturbed: exp clock shows 9:47 PM. How many minutes 'til 10?\n",
- "Chat Completion - Perturbed: 0.73s\n",
- "Attributions computation: 0.009429931640625s\n",
- "- Cosine Attr: 0.0s\n",
- "- Prob Diff Attr: 0.0s\n",
- "- Token Displacement Attr: 0.0s\n",
- "- Attr Logging: 0.0s\n",
- "Total for word 'The': 0.8s\n",
+ "Perturbed: Limit clock shows 9:47 PM. How many minutes 'til 10?\n",
+ "Chat Completion - Perturbed: 0.65s\n",
+ "cosine attributed_tokens ['13', ' minutes', '.']\n",
+ "prob_diff attributed_tokens ['13', ' minutes', '.']\n",
+ "token_displacement attributed_tokens ['13', ' minutes', '.']\n",
"\n",
- "Replaced word 'clock': 0.05s - get_replacement_token()\n",
+ "Replaced word 'Ġclock': 0.24s - get_replacement_token()\n",
"Original: The clock shows 9:47 PM. How many minutes 'til 10?\n",
- "Perturbed: Thework shows 9:47 PM. How many minutes 'til 10?\n",
- "Chat Completion - Perturbed: 0.91s\n",
- "Attributions computation: 0.008873939514160156s\n",
- "- Cosine Attr: 0.0s\n",
- "- Prob Diff Attr: 0.0s\n",
- "- Token Displacement Attr: 0.0s\n",
- "- Attr Logging: 0.0s\n",
- "Total for word 'clock': 0.97s\n",
+ "Perturbed: The Penny shows 9:47 PM. How many minutes 'til 10?\n",
+ "Chat Completion - Perturbed: 4.96s\n",
+ "cosine attributed_tokens ['13', ' minutes', '.']\n",
+ "prob_diff attributed_tokens ['13', ' minutes', '.']\n",
+ "token_displacement attributed_tokens ['13', ' minutes', '.']\n",
"\n",
- "Replaced word 'shows': 0.05s - get_replacement_token()\n",
+ "Replaced word 'Ġshows': 0.36s - get_replacement_token()\n",
"Original: The clock shows 9:47 PM. How many minutes 'til 10?\n",
- "Perturbed: The clock cheesy 9:47 PM. How many minutes 'til 10?\n",
- "Chat Completion - Perturbed: 0.37s\n",
- "Attributions computation: 0.00640106201171875s\n",
- "- Cosine Attr: 0.0s\n",
- "- Prob Diff Attr: 0.0s\n",
- "- Token Displacement Attr: 0.0s\n",
- "- Attr Logging: 0.0s\n",
- "Total for word 'shows': 0.43s\n",
+ "Perturbed: The clockmedia 9:47 PM. How many minutes 'til 10?\n",
+ "Chat Completion - Perturbed: 1.08s\n",
+ "cosine attributed_tokens ['13', ' minutes', '.']\n",
+ "prob_diff attributed_tokens ['13', ' minutes', '.']\n",
+ "token_displacement attributed_tokens ['13', ' minutes', '.']\n",
"\n",
- "Replaced word '9:47': 0.12s - get_replacement_token()\n",
+ "Replaced word 'Ġ9:47': 0.71s - get_replacement_token()\n",
"Original: The clock shows 9:47 PM. How many minutes 'til 10?\n",
- "Perturbed: The clock shows teaspASON unpredict PM. How many minutes 'til 10?\n",
- "Chat Completion - Perturbed: 1.59s\n",
- "Attributions computation: 0.022313833236694336s\n",
- "- Cosine Attr: 0.0s\n",
- "- Prob Diff Attr: 0.0s\n",
- "- Token Displacement Attr: 0.0s\n",
- "- Attr Logging: 0.01s\n",
- "Total for word '9:47': 1.74s\n",
+ "Perturbed: The clock shows accountingASON unpredict PM. How many minutes 'til 10?\n",
+ "Chat Completion - Perturbed: 1.38s\n",
+ "cosine attributed_tokens ['13', ' minutes', '.']\n",
+ "prob_diff attributed_tokens ['13', ' minutes', '.']\n",
+ "token_displacement attributed_tokens ['13', ' minutes', '.']\n",
"\n",
- "Replaced word 'PM.': 0.08s - get_replacement_token()\n",
+ "Replaced word 'ĠPM.': 0.52s - get_replacement_token()\n",
"Original: The clock shows 9:47 PM. How many minutes 'til 10?\n",
- "Perturbed: The clock shows 9:47 bombingroc How many minutes 'til 10?\n",
- "Chat Completion - Perturbed: 0.94s\n",
- "Attributions computation: 0.017086029052734375s\n",
- "- Cosine Attr: 0.0s\n",
- "- Prob Diff Attr: 0.0s\n",
- "- Token Displacement Attr: 0.0s\n",
- "- Attr Logging: 0.0s\n",
- "Total for word 'PM.': 1.04s\n",
+ "Perturbed: The clock shows 9:47 deviationroc How many minutes 'til 10?\n",
+ "Chat Completion - Perturbed: 1.08s\n",
+ "cosine attributed_tokens ['13', ' minutes', '.']\n",
+ "prob_diff attributed_tokens ['13', ' minutes', '.']\n",
+ "token_displacement attributed_tokens ['13', ' minutes', '.']\n",
"\n",
- "Replaced word 'How': 0.04s - get_replacement_token()\n",
+ "Replaced word 'ĠHow': 0.25s - get_replacement_token()\n",
"Original: The clock shows 9:47 PM. How many minutes 'til 10?\n",
- "Perturbed: The clock shows 9:47 PM. empir many minutes 'til 10?\n",
- "Chat Completion - Perturbed: 1.1s\n",
- "Attributions computation: 0.008667945861816406s\n",
- "- Cosine Attr: 0.0s\n",
- "- Prob Diff Attr: 0.0s\n",
- "- Token Displacement Attr: 0.0s\n",
- "- Attr Logging: 0.0s\n",
- "Total for word 'How': 1.15s\n",
+ "Perturbed: The clock shows 9:47 PM. cryptography many minutes 'til 10?\n",
+ "Chat Completion - Perturbed: 1.2s\n",
+ "cosine attributed_tokens ['13', ' minutes', '.']\n",
+ "prob_diff attributed_tokens ['13', ' minutes', '.']\n",
+ "token_displacement attributed_tokens ['13', ' minutes', '.']\n",
"\n",
- "Replaced word 'many': 0.04s - get_replacement_token()\n",
+ "Replaced word 'Ġmany': 0.26s - get_replacement_token()\n",
"Original: The clock shows 9:47 PM. How many minutes 'til 10?\n",
- "Perturbed: The clock shows 9:47 PM. How pubs minutes 'til 10?\n",
- "Chat Completion - Perturbed: 0.9s\n",
- "Attributions computation: 0.010850906372070312s\n",
- "- Cosine Attr: 0.0s\n",
- "- Prob Diff Attr: 0.0s\n",
- "- Token Displacement Attr: 0.0s\n",
- "- Attr Logging: 0.0s\n",
- "Total for word 'many': 0.95s\n",
+ "Perturbed: The clock shows 9:47 PM. How acidic minutes 'til 10?\n",
+ "Chat Completion - Perturbed: 1.11s\n",
+ "cosine attributed_tokens ['13', ' minutes', '.']\n",
+ "prob_diff attributed_tokens ['13', ' minutes', '.']\n",
+ "token_displacement attributed_tokens ['13', ' minutes', '.']\n",
"\n",
- "Replaced word 'minutes': 0.09s - get_replacement_token()\n",
+ "Replaced word 'Ġminutes': 0.29s - get_replacement_token()\n",
"Original: The clock shows 9:47 PM. How many minutes 'til 10?\n",
- "Perturbed: The clock shows 9:47 PM. How manyRespakotil 10?\n",
- "Chat Completion - Perturbed: 0.73s\n",
- "Attributions computation: 0.019688844680786133s\n",
- "- Cosine Attr: 0.0s\n",
- "- Prob Diff Attr: 0.0s\n",
- "- Token Displacement Attr: 0.0s\n",
- "- Attr Logging: 0.0s\n",
- "Total for word 'minutes': 0.84s\n",
+ "Perturbed: The clock shows 9:47 PM. How manystyles 'til 10?\n",
+ "Chat Completion - Perturbed: 0.87s\n",
+ "cosine attributed_tokens ['13', ' minutes', '.']\n",
+ "prob_diff attributed_tokens ['13', ' minutes', '.']\n",
+ "token_displacement attributed_tokens ['13', ' minutes', '.']\n",
"\n",
- "Replaced word ''til': 0.09s - get_replacement_token()\n",
+ "Replaced word 'Ġ'til': 0.55s - get_replacement_token()\n",
"Original: The clock shows 9:47 PM. How many minutes 'til 10?\n",
- "Perturbed: The clock shows 9:47 PM. How many minutes'refugees Naturally?\n",
- "Chat Completion - Perturbed: 0.69s\n",
- "Attributions computation: 0.022047996520996094s\n",
- "- Cosine Attr: 0.0s\n",
- "- Prob Diff Attr: 0.0s\n",
- "- Token Displacement Attr: 0.0s\n",
- "- Attr Logging: 0.0s\n",
- "Total for word ''til': 0.8s\n",
+ "Perturbed: The clock shows 9:47 PM. How many minutes Wah clich 10?\n",
+ "Chat Completion - Perturbed: 0.86s\n",
+ "cosine attributed_tokens ['13', ' minutes', '.']\n",
+ "prob_diff attributed_tokens ['13', ' minutes', '.']\n",
+ "token_displacement attributed_tokens ['13', ' minutes', '.']\n",
"\n",
- "Replaced word '10?': 0.09s - get_replacement_token()\n",
+ "Replaced word 'Ġ10?': 0.48s - get_replacement_token()\n",
"Original: The clock shows 9:47 PM. How many minutes 'til 10?\n",
- "Perturbed: The clock shows 9:47 PM. How many minutes 'til 10 setbacks susceptible\n",
- "Chat Completion - Perturbed: 0.82s\n",
- "Attributions computation: 0.019437074661254883s\n",
- "- Cosine Attr: 0.0s\n",
- "- Prob Diff Attr: 0.0s\n",
- "- Token Displacement Attr: 0.0s\n",
- "- Attr Logging: 0.0s\n",
- "Total for word '10?': 0.93s\n",
+ "Perturbed: The clock shows 9:47 PM. How many minutes 'til lawsuit susceptible\n",
+ "Chat Completion - Perturbed: 0.61s\n",
+ "cosine attributed_tokens ['13', ' minutes', '.']\n",
+ "prob_diff attributed_tokens ['13', ' minutes', '.']\n",
+ "token_displacement attributed_tokens ['13', ' minutes', '.']\n",
"\n",
"\n",
- "Exp Total: 9.65583610534668s\n",
+ "Exp Total: 18.051724910736084s\n",
"\n",
"\n",
"The clock shows 9:47 PM. How many minutes 'til 10? ('13 minutes.',)\n"
@@ -349,8 +406,7 @@
}
],
"source": [
- "input_sequences = [\n",
- " \"The clock shows 9:47 PM. How many minutes 'til 10?\"]\n",
+ "input_texts = [\"The clock shows 9:47 PM. How many minutes 'til 10?\"]\n",
"# \"The building is 132 meters tall. How tall is the building?\",\n",
"# \"The package weighs 8.6 kilograms. How much does the package weigh?\",\n",
"# \"The thermometer reads 23 degrees Celsius. What is the temperature according to the thermometer?\",\n",
@@ -363,26 +419,27 @@
"# ]\n",
"\n",
"\n",
- "for input_sequence in input_sequences:\n",
+ "for input_text in input_texts:\n",
" for perturbation_strategy in [\"distant\"]:\n",
- " original_output = calculate_token_importance_in_sequence(\n",
- " input_sequence,\n",
+ " original_output = calculate_token_importance(\n",
+ " input_text,\n",
" model,\n",
" tokenizer,\n",
" perturbation_strategy,\n",
" attribution_strategies=[\"cosine\", \"prob_diff\", \"token_displacement\"],\n",
" logger=logger,\n",
+ " perturb_word_wise=True,\n",
" )\n",
"\n",
" print(\n",
- " input_sequence,\n",
+ " input_text,\n",
" original_output,\n",
" )"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 18,
"metadata": {},
"outputs": [
{
@@ -410,6 +467,7 @@
"
original_input | \n",
" original_output | \n",
" perturbation_strategy | \n",
+ " perturb_word_wise | \n",
" duration | \n",
" \n",
" \n",
@@ -420,7 +478,8 @@
" The clock shows 9:47 PM. How many minutes 'til... | \n",
" 13 minutes. | \n",
" distant | \n",
- " 9.658044 | \n",
+ " True | \n",
+ " 18.055734 | \n",
" \n",
" \n",
"\n",
@@ -430,8 +489,8 @@
" exp_id original_input original_output \\\n",
"0 1 The clock shows 9:47 PM. How many minutes 'til... 13 minutes. \n",
"\n",
- " perturbation_strategy duration \n",
- "0 distant 9.658044 "
+ " perturbation_strategy perturb_word_wise duration \n",
+ "0 distant True 18.055734 "
]
},
"metadata": {},
@@ -444,162 +503,159 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
- "\n",
+ "\n",
" \n",
" \n",
" | \n",
- " exp_id | \n",
- " attribution_strategy | \n",
- " perturbation_strategy | \n",
- " token_1 | \n",
- " token_2 | \n",
- " token_3 | \n",
- " token_4 | \n",
- " token_5 | \n",
- " token_6 | \n",
- " token_7 | \n",
- " token_8 | \n",
- " token_9 | \n",
- " token_10 | \n",
- " token_11 | \n",
- " token_12 | \n",
- " token_13 | \n",
- " token_14 | \n",
- " token_15 | \n",
- " token_16 | \n",
+ " exp_id | \n",
+ " attribution_strategy | \n",
+ " perturbation_strategy | \n",
+ " perturb_word_wise | \n",
+ " token_1 | \n",
+ " token_2 | \n",
+ " token_3 | \n",
+ " token_4 | \n",
+ " token_5 | \n",
+ " token_6 | \n",
+ " token_7 | \n",
+ " token_8 | \n",
+ " token_9 | \n",
+ " token_10 | \n",
+ " token_11 | \n",
+ " token_12 | \n",
+ " token_13 | \n",
+ " token_14 | \n",
+ " token_15 | \n",
"
\n",
" \n",
" \n",
" \n",
- " 0 | \n",
- " 1 | \n",
- " cosine | \n",
- " distant | \n",
- " The\n",
- "0.14 | \n",
- " clock\n",
- "0.14 | \n",
- " shows\n",
- "0.10 | \n",
- " 9\n",
- "0.20 | \n",
- " :\n",
- "0.20 | \n",
- " 47\n",
- "0.20 | \n",
- " PM\n",
- "0.17 | \n",
- " .\n",
- "0.17 | \n",
- " How\n",
- "0.14 | \n",
- " many\n",
- "0.14 | \n",
- " min\n",
- "0.14 | \n",
- " utes\n",
- "0.14 | \n",
- " 't\n",
- "0.10 | \n",
- " il\n",
- "0.10 | \n",
- " 10\n",
+ " | 0 | \n",
+ " 1 | \n",
+ " cosine | \n",
+ " distant | \n",
+ " True | \n",
+ " The\n",
"0.00 | \n",
- " ?\n",
+ " | clock\n",
"0.00 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 1 | \n",
- " prob_diff | \n",
- " distant | \n",
- " The\n",
- "0.76 | \n",
- " clock\n",
- "0.70 | \n",
- " shows\n",
- "0.28 | \n",
- " 9\n",
- "0.80 | \n",
- " :\n",
- "0.80 | \n",
- " 47\n",
- "0.80 | \n",
- " PM\n",
- "0.78 | \n",
- " .\n",
- "0.78 | \n",
- " How\n",
- "0.79 | \n",
- " many\n",
- "0.72 | \n",
- " min\n",
- "0.80 | \n",
- " utes\n",
+ " | shows\n",
+ "0.00 | \n",
+ " 9\n",
+ "0.15 | \n",
+ " :\n",
+ "0.15 | \n",
+ " 47\n",
+ "0.15 | \n",
+ " PM\n",
+ "0.15 | \n",
+ " .\n",
+ "0.15 | \n",
+ " How\n",
+ "0.13 | \n",
+ " many\n",
+ "0.13 | \n",
+ " minutes\n",
+ "0.13 | \n",
+ " '\n",
+ "0.00 | \n",
+ " til\n",
+ "0.00 | \n",
+ " 10\n",
+ "0.11 | \n",
+ " ?\n",
+ "0.11 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " prob_diff | \n",
+ " distant | \n",
+ " True | \n",
+ " The\n",
+ "0.08 | \n",
+ " clock\n",
+ "0.12 | \n",
+ " shows\n",
+ "0.11 | \n",
+ " 9\n",
+ "0.82 | \n",
+ " :\n",
+ "0.82 | \n",
+ " 47\n",
+ "0.82 | \n",
+ " PM\n",
+ "0.81 | \n",
+ " .\n",
+ "0.81 | \n",
+ " How\n",
+ "0.81 | \n",
+ " many\n",
"0.80 | \n",
- " 't\n",
- "0.29 | \n",
- " il\n",
- "0.29 | \n",
- " 10\n",
- "0.10 | \n",
- " ?\n",
- "0.10 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 1 | \n",
- " token_displacement | \n",
- " distant | \n",
- " The\n",
+ " | minutes\n",
+ "0.82 | \n",
+ " '\n",
+ "0.05 | \n",
+ " til\n",
+ "0.05 | \n",
+ " 10\n",
+ "0.24 | \n",
+ " ?\n",
+ "0.24 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1 | \n",
+ " token_displacement | \n",
+ " distant | \n",
+ " True | \n",
+ " The\n",
+ "0.00 | \n",
+ " clock\n",
+ "0.00 | \n",
+ " shows\n",
+ "0.00 | \n",
+ " 9\n",
+ "19.33 | \n",
+ " :\n",
+ "19.33 | \n",
+ " 47\n",
+ "19.33 | \n",
+ " PM\n",
+ "12.67 | \n",
+ " .\n",
+ "12.67 | \n",
+ " How\n",
"13.67 | \n",
- " clock\n",
- "13.00 | \n",
- " shows\n",
- "6.67 | \n",
- " 9\n",
- "20.00 | \n",
- " :\n",
- "20.00 | \n",
- " 47\n",
- "20.00 | \n",
- " PM\n",
- "13.33 | \n",
- " .\n",
- "13.33 | \n",
- " How\n",
+ " | many\n",
"13.67 | \n",
- " many\n",
+ " | minutes\n",
"13.67 | \n",
- " min\n",
- "16.67 | \n",
- " utes\n",
- "16.67 | \n",
- " 't\n",
- "6.67 | \n",
- " il\n",
- "6.67 | \n",
- " 10\n",
+ " | '\n",
"0.00 | \n",
- " ?\n",
+ " | til\n",
"0.00 | \n",
+ " 10\n",
+ "6.67 | \n",
+ " ?\n",
+ "6.67 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
- ""
+ ""
]
},
"metadata": {},
@@ -612,7 +668,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 23,
"metadata": {},
"outputs": [
{
@@ -627,134 +683,128 @@
"data": {
"text/html": [
"\n",
- "\n",
+ "\n",
" \n",
" \n",
" | \n",
- " The (0) | \n",
- " clock (1) | \n",
- " shows (2) | \n",
+ " 13 (0) | \n",
+ " minutes (1) | \n",
+ " . (2) | \n",
"
\n",
" \n",
" \n",
" \n",
- " The (0) | \n",
- " 0.600346 | \n",
- " 0.759586 | \n",
- " 0.701984 | \n",
+ " The (0) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
"
\n",
" \n",
- " clock (1) | \n",
- " 0.600346 | \n",
- " 0.759586 | \n",
- " 0.701984 | \n",
+ " clock (1) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
"
\n",
" \n",
- " shows (2) | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 1.000000 | \n",
+ " shows (2) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
"
\n",
" \n",
- " 9 (3) | \n",
- " 0.600346 | \n",
- " 0.759586 | \n",
- " 0.747480 | \n",
+ " 9 (3) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.747480 | \n",
"
\n",
" \n",
- " : (4) | \n",
- " 0.600346 | \n",
- " 0.759586 | \n",
- " 0.747480 | \n",
+ " : (4) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.747480 | \n",
"
\n",
" \n",
- " 47 (5) | \n",
- " 0.600346 | \n",
- " 0.759586 | \n",
- " 0.747480 | \n",
+ " 47 (5) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.747480 | \n",
"
\n",
" \n",
- " PM (6) | \n",
- " 0.600346 | \n",
- " 0.759586 | \n",
- " 0.701984 | \n",
+ " PM (6) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.701984 | \n",
"
\n",
" \n",
- " . (7) | \n",
- " 0.600346 | \n",
- " 0.759586 | \n",
- " 0.701984 | \n",
+ " . (7) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.701984 | \n",
"
\n",
" \n",
- " How (8) | \n",
- " 0.600346 | \n",
- " 0.759586 | \n",
- " 0.701984 | \n",
+ " How (8) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.701984 | \n",
"
\n",
" \n",
- " many (9) | \n",
- " 0.600346 | \n",
- " 0.759586 | \n",
- " 0.701984 | \n",
+ " many (9) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.701984 | \n",
"
\n",
" \n",
- " min (10) | \n",
- " 0.600346 | \n",
- " 0.759586 | \n",
- " 0.701984 | \n",
+ " minutes (10) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.701984 | \n",
"
\n",
" \n",
- " utes (11) | \n",
- " 0.600346 | \n",
- " 0.759586 | \n",
- " 0.701984 | \n",
+ " ' (11) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
"
\n",
" \n",
- " 't (12) | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 1.000000 | \n",
+ " til (12) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
"
\n",
" \n",
- " il (13) | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 1.000000 | \n",
+ " 10 (13) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 1.000000 | \n",
"
\n",
" \n",
- " 10 (14) | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " ? (15) | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
+ " ? (14) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 1.000000 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
- ""
+ ""
]
},
"metadata": {},
@@ -772,154 +822,160 @@
"data": {
"text/html": [
"\n",
- "\n",
+ "\n",
" \n",
" \n",
" | \n",
- " 13 (0) | \n",
- " minutes (1) | \n",
- " . (2) | \n",
+ " 13 (0) | \n",
+ " minutes (1) | \n",
+ " . (2) | \n",
"
\n",
" \n",
" \n",
" \n",
- " The (0) | \n",
- " 0.646714 | \n",
- " 0.999919 | \n",
- " 0.619983 | \n",
- "
\n",
- " \n",
- " clock (1) | \n",
- " 0.489712 | \n",
- " 0.999919 | \n",
- " 0.619983 | \n",
+ " The (0) | \n",
+ " 0.239891 | \n",
+ " 0.000012 | \n",
+ " 0.004187 | \n",
"
\n",
" \n",
- " shows (2) | \n",
- " 0.231651 | \n",
- " 0.000033 | \n",
- " 0.619983 | \n",
+ " clock (1) | \n",
+ " 0.339138 | \n",
+ " 0.000061 | \n",
+ " 0.025875 | \n",
"
\n",
" \n",
- " 9 (3) | \n",
- " 0.785889 | \n",
- " 0.999919 | \n",
- " 0.619983 | \n",
+ " shows (2) | \n",
+ " 0.299949 | \n",
+ " 0.000029 | \n",
+ " 0.024715 | \n",
"
\n",
" \n",
- " : (4) | \n",
- " 0.785889 | \n",
- " 0.999919 | \n",
- " 0.619983 | \n",
+ " 9 (3) | \n",
+ " 0.817992 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
"
\n",
" \n",
- " 47 (5) | \n",
- " 0.785889 | \n",
- " 0.999919 | \n",
- " 0.619983 | \n",
+ " : (4) | \n",
+ " 0.817992 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
"
\n",
" \n",
- " PM (6) | \n",
- " 0.707377 | \n",
- " 0.999919 | \n",
- " 0.619983 | \n",
+ " 47 (5) | \n",
+ " 0.817992 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
"
\n",
" \n",
- " . (7) | \n",
- " 0.707377 | \n",
- " 0.999919 | \n",
- " 0.619983 | \n",
+ " PM (6) | \n",
+ " 0.804387 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
"
\n",
" \n",
- " How (8) | \n",
- " 0.754307 | \n",
- " 0.999919 | \n",
- " 0.619983 | \n",
+ " . (7) | \n",
+ " 0.804387 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
"
\n",
" \n",
- " many (9) | \n",
- " 0.535038 | \n",
- " 0.999919 | \n",
- " 0.619983 | \n",
+ " How (8) | \n",
+ " 0.787602 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
"
\n",
" \n",
- " min (10) | \n",
- " 0.785592 | \n",
- " 0.999919 | \n",
- " 0.619983 | \n",
+ " many (9) | \n",
+ " 0.758278 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
"
\n",
" \n",
- " utes (11) | \n",
- " 0.785592 | \n",
- " 0.999919 | \n",
- " 0.619983 | \n",
+ " minutes (10) | \n",
+ " 0.814092 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
"
\n",
" \n",
- " 't (12) | \n",
- " 0.257238 | \n",
- " 0.000704 | \n",
- " 0.619983 | \n",
+ " ' (11) | \n",
+ " 0.042329 | \n",
+ " 0.001205 | \n",
+ " 0.102655 | \n",
"
\n",
" \n",
- " il (13) | \n",
- " 0.257238 | \n",
- " 0.000704 | \n",
- " 0.619983 | \n",
+ " til (12) | \n",
+ " 0.042329 | \n",
+ " 0.001205 | \n",
+ " 0.102655 | \n",
"
\n",
" \n",
- " 10 (14) | \n",
- " 0.113566 | \n",
- " 0.000129 | \n",
- " 0.194634 | \n",
+ " 10 (13) | \n",
+ " 0.082409 | \n",
+ " 0.000164 | \n",
+ " 0.640487 | \n",
"
\n",
" \n",
- " ? (15) | \n",
- " 0.113566 | \n",
- " 0.000129 | \n",
- " 0.194634 | \n",
+ " ? (14) | \n",
+ " 0.082409 | \n",
+ " 0.000164 | \n",
+ " 0.640487 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
- ""
+ ""
]
},
"metadata": {},
@@ -937,142 +993,132 @@
"data": {
"text/html": [
"\n",
- "\n",
+ "\n",
" \n",
" \n",
" | \n",
- " 13 (0) | \n",
- " minutes (1) | \n",
- " . (2) | \n",
+ " 13 (0) | \n",
+ " minutes (1) | \n",
+ " . (2) | \n",
"
\n",
" \n",
" \n",
" \n",
- " The (0) | \n",
- " 1.000000 | \n",
- " 20.000000 | \n",
- " 20.000000 | \n",
+ " The (0) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
"
\n",
" \n",
- " clock (1) | \n",
- " 1.000000 | \n",
- " 18.000000 | \n",
- " 20.000000 | \n",
+ " clock (1) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
"
\n",
" \n",
- " shows (2) | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 20.000000 | \n",
+ " shows (2) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
"
\n",
" \n",
- " 9 (3) | \n",
- " 20.000000 | \n",
- " 20.000000 | \n",
- " 20.000000 | \n",
+ " 9 (3) | \n",
+ " 20.000000 | \n",
+ " 18.000000 | \n",
+ " 20.000000 | \n",
"
\n",
" \n",
- " : (4) | \n",
- " 20.000000 | \n",
- " 20.000000 | \n",
- " 20.000000 | \n",
+ " : (4) | \n",
+ " 20.000000 | \n",
+ " 18.000000 | \n",
+ " 20.000000 | \n",
"
\n",
" \n",
- " 47 (5) | \n",
- " 20.000000 | \n",
- " 20.000000 | \n",
- " 20.000000 | \n",
+ " 47 (5) | \n",
+ " 20.000000 | \n",
+ " 18.000000 | \n",
+ " 20.000000 | \n",
"
\n",
" \n",
- " PM (6) | \n",
- " 1.000000 | \n",
- " 19.000000 | \n",
- " 20.000000 | \n",
+ " PM (6) | \n",
+ " 1.000000 | \n",
+ " 17.000000 | \n",
+ " 20.000000 | \n",
"
\n",
" \n",
- " . (7) | \n",
- " 1.000000 | \n",
- " 19.000000 | \n",
- " 20.000000 | \n",
+ " . (7) | \n",
+ " 1.000000 | \n",
+ " 17.000000 | \n",
+ " 20.000000 | \n",
"
\n",
" \n",
- " How (8) | \n",
- " 1.000000 | \n",
- " 20.000000 | \n",
- " 20.000000 | \n",
+ " How (8) | \n",
+ " 1.000000 | \n",
+ " 20.000000 | \n",
+ " 20.000000 | \n",
"
\n",
" \n",
- " many (9) | \n",
- " 1.000000 | \n",
- " 20.000000 | \n",
- " 20.000000 | \n",
+ " many (9) | \n",
+ " 1.000000 | \n",
+ " 20.000000 | \n",
+ " 20.000000 | \n",
"
\n",
" \n",
- " min (10) | \n",
- " 10.000000 | \n",
- " 20.000000 | \n",
- " 20.000000 | \n",
+ " minutes (10) | \n",
+ " 1.000000 | \n",
+ " 20.000000 | \n",
+ " 20.000000 | \n",
"
\n",
" \n",
- " utes (11) | \n",
- " 10.000000 | \n",
- " 20.000000 | \n",
- " 20.000000 | \n",
+ " ' (11) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
"
\n",
" \n",
- " 't (12) | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 20.000000 | \n",
+ " til (12) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
"
\n",
" \n",
- " il (13) | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 20.000000 | \n",
+ " 10 (13) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 20.000000 | \n",
"
\n",
" \n",
- " 10 (14) | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " ? (15) | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
+ " ? (14) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 20.000000 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
- ""
+ ""
]
},
"metadata": {},
@@ -1085,14 +1131,14 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Attribution matrix for cosine with perturbation strategy distant:\n",
+ "Attribution matrix for prob_diff with perturbation strategy distant:\n",
"Input Tokens (Rows) vs. Output Tokens (Columns)\n"
]
},
@@ -1100,240 +1146,160 @@
"data": {
"text/html": [
"\n",
- "\n",
+ "\n",
" \n",
" \n",
" | \n",
- " The (0) | \n",
- " clock (1) | \n",
- " shows (2) | \n",
- " 9 (3) | \n",
- " : (4) | \n",
- " 47 (5) | \n",
- " PM (6) | \n",
- " . (7) | \n",
+ " 13 (0) | \n",
+ " minutes (1) | \n",
+ " . (2) | \n",
"
\n",
" \n",
" \n",
" \n",
- " The (0) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " clock (1) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " shows (2) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " 9 (3) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.289177 | \n",
- " 0.000000 | \n",
- " 0.274936 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " : (4) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.289177 | \n",
- " 0.000000 | \n",
- " 0.274936 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " 47 (5) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.289177 | \n",
- " 0.000000 | \n",
- " 0.274936 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " PM (6) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.821500 | \n",
- " 1.000000 | \n",
- "
\n",
- " \n",
- " . (7) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.821500 | \n",
- " 1.000000 | \n",
- "
\n",
- " \n",
- " What (8) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " time (9) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.488054 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.821500 | \n",
- " 1.000000 | \n",
- "
\n",
- " \n",
- " does (10) | \n",
- " 0.560580 | \n",
- " 0.856970 | \n",
- " 0.696821 | \n",
- " 0.746326 | \n",
- " 1.000000 | \n",
- " 1.000000 | \n",
- " 1.000000 | \n",
- " 1.000000 | \n",
- "
\n",
- " \n",
- " the (11) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " clock (12) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " show (13) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.828613 | \n",
- " 0.754033 | \n",
- " 0.513017 | \n",
- " 0.693778 | \n",
- " 0.781776 | \n",
- " 0.758650 | \n",
- "
\n",
- " \n",
- " ? (14) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.828613 | \n",
- " 0.754033 | \n",
- " 0.513017 | \n",
- " 0.693778 | \n",
- " 0.781776 | \n",
- " 0.758650 | \n",
+ " The (0) | \n",
+ " 0.239891 | \n",
+ " 0.000012 | \n",
+ " 0.004187 | \n",
+ "
\n",
+ " \n",
+ " clock (1) | \n",
+ " 0.339138 | \n",
+ " 0.000061 | \n",
+ " 0.025875 | \n",
+ "
\n",
+ " \n",
+ " shows (2) | \n",
+ " 0.299949 | \n",
+ " 0.000029 | \n",
+ " 0.024715 | \n",
+ "
\n",
+ " \n",
+ " 9 (3) | \n",
+ " 0.817992 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
+ "
\n",
+ " \n",
+ " : (4) | \n",
+ " 0.817992 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
+ "
\n",
+ " \n",
+ " 47 (5) | \n",
+ " 0.817992 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
+ "
\n",
+ " \n",
+ " PM (6) | \n",
+ " 0.804387 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
+ "
\n",
+ " \n",
+ " . (7) | \n",
+ " 0.804387 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
+ "
\n",
+ " \n",
+ " How (8) | \n",
+ " 0.787602 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
+ "
\n",
+ " \n",
+ " many (9) | \n",
+ " 0.758278 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
+ "
\n",
+ " \n",
+ " minutes (10) | \n",
+ " 0.814092 | \n",
+ " 0.999872 | \n",
+ " 0.640487 | \n",
+ "
\n",
+ " \n",
+ " ' (11) | \n",
+ " 0.042329 | \n",
+ " 0.001205 | \n",
+ " 0.102655 | \n",
+ "
\n",
+ " \n",
+ " til (12) | \n",
+ " 0.042329 | \n",
+ " 0.001205 | \n",
+ " 0.102655 | \n",
+ "
\n",
+ " \n",
+ " 10 (13) | \n",
+ " 0.082409 | \n",
+ " 0.000164 | \n",
+ " 0.640487 | \n",
+ "
\n",
+ " \n",
+ " ? (14) | \n",
+ " 0.082409 | \n",
+ " 0.000164 | \n",
+ " 0.640487 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
- ""
+ ""
]
},
"metadata": {},
@@ -1341,12 +1307,12 @@
}
],
"source": [
- "logger.print_attribution_matrix(1, \"cosine\")"
+ "logger.print_attribution_matrix(exp_id=1, attribution_strategy=\"prob_diff\")"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
@@ -1361,583 +1327,128 @@
"data": {
"text/html": [
"\n",
- "\n",
+ "\n",
" \n",
" \n",
" | \n",
- " The (0) | \n",
- " clock (1) | \n",
- " shows (2) | \n",
- " 9 (3) | \n",
- " : (4) | \n",
- " 47 (5) | \n",
- " PM (6) | \n",
- " . (7) | \n",
+ " 13 (0) | \n",
+ " minutes (1) | \n",
+ " . (2) | \n",
"
\n",
" \n",
" \n",
" \n",
- " The (0) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " clock (1) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " shows (2) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " 9 (3) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.289177 | \n",
- " 0.000000 | \n",
- " 0.274936 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " : (4) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.289177 | \n",
- " 0.000000 | \n",
- " 0.274936 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " 47 (5) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.289177 | \n",
- " 0.000000 | \n",
- " 0.274936 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " PM (6) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.821500 | \n",
- " 1.000000 | \n",
- "
\n",
- " \n",
- " . (7) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.821500 | \n",
- " 1.000000 | \n",
- "
\n",
- " \n",
- " What (8) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " time (9) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.488054 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.821500 | \n",
- " 1.000000 | \n",
- "
\n",
- " \n",
- " does (10) | \n",
- " 0.560580 | \n",
- " 0.856970 | \n",
- " 0.696821 | \n",
- " 0.746326 | \n",
- " 1.000000 | \n",
- " 1.000000 | \n",
- " 1.000000 | \n",
- " 1.000000 | \n",
- "
\n",
- " \n",
- " the (11) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " clock (12) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " show (13) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.828613 | \n",
- " 0.754033 | \n",
- " 0.513017 | \n",
- " 0.693778 | \n",
- " 0.781776 | \n",
- " 0.758650 | \n",
- "
\n",
- " \n",
- " ? (14) | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.828613 | \n",
- " 0.754033 | \n",
- " 0.513017 | \n",
- " 0.693778 | \n",
- " 0.781776 | \n",
- " 0.758650 | \n",
+ " The (0) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
"
\n",
- " \n",
- "
\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Attribution matrix for prob_diff with perturbation strategy distant:\n",
- "Input Tokens (Rows) vs. Output Tokens (Columns)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- " \n",
" \n",
- " | \n",
- " The (0) | \n",
- " clock (1) | \n",
- " shows (2) | \n",
- " (3) | \n",
- " 9 (4) | \n",
- " : (5) | \n",
- " 47 (6) | \n",
- " PM (7) | \n",
- " . (8) | \n",
+ " clock (1) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " shows (2) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " 9 (3) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.747480 | \n",
+ "
\n",
+ " \n",
+ " : (4) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.747480 | \n",
+ "
\n",
+ " \n",
+ " 47 (5) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.747480 | \n",
+ "
\n",
+ " \n",
+ " PM (6) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.701984 | \n",
"
\n",
- " \n",
- " \n",
" \n",
- " The (0) | \n",
- " 0.064448 | \n",
- " 0.007869 | \n",
- " 0.003085 | \n",
- " 0.002989 | \n",
- " 0.000108 | \n",
- " 0.000128 | \n",
- " 0.000118 | \n",
- " 0.000054 | \n",
- " 0.007132 | \n",
- "
\n",
- " \n",
- " clock (1) | \n",
- " 0.136458 | \n",
- " 0.033954 | \n",
- " 0.003810 | \n",
- " 0.002736 | \n",
- " 0.000285 | \n",
- " 0.000101 | \n",
- " 0.001414 | \n",
- " 0.000432 | \n",
- " 0.000383 | \n",
- "
\n",
- " \n",
- " shows (2) | \n",
- " 0.151668 | \n",
- " 0.006804 | \n",
- " 0.009290 | \n",
- " 0.006151 | \n",
- " 0.000079 | \n",
- " 0.000141 | \n",
- " 0.000733 | \n",
- " 0.000125 | \n",
- " 0.016142 | \n",
- "
\n",
- " \n",
- " 9 (3) | \n",
- " 0.191758 | \n",
- " 0.008194 | \n",
- " 0.016663 | \n",
- " 0.040286 | \n",
- " 0.951041 | \n",
- " 0.000315 | \n",
- " 0.958258 | \n",
- " 0.000138 | \n",
- " 0.001367 | \n",
- "
\n",
- " \n",
- " : (4) | \n",
- " 0.191758 | \n",
- " 0.008194 | \n",
- " 0.016663 | \n",
- " 0.040286 | \n",
- " 0.951041 | \n",
- " 0.000315 | \n",
- " 0.958258 | \n",
- " 0.000138 | \n",
- " 0.001367 | \n",
- "
\n",
- " \n",
- " 47 (5) | \n",
- " 0.191758 | \n",
- " 0.008194 | \n",
- " 0.016663 | \n",
- " 0.040286 | \n",
- " 0.951041 | \n",
- " 0.000315 | \n",
- " 0.958258 | \n",
- " 0.000138 | \n",
- " 0.001367 | \n",
- "
\n",
- " \n",
- " PM (6) | \n",
- " 0.171069 | \n",
- " 0.000117 | \n",
- " 0.001073 | \n",
- " 0.000121 | \n",
- " 0.000113 | \n",
- " 0.000139 | \n",
- " 0.000016 | \n",
- " 0.999858 | \n",
- " 0.994526 | \n",
- "
\n",
- " \n",
- " . (7) | \n",
- " 0.171069 | \n",
- " 0.000117 | \n",
- " 0.001073 | \n",
- " 0.000121 | \n",
- " 0.000113 | \n",
- " 0.000139 | \n",
- " 0.000016 | \n",
- " 0.999858 | \n",
- " 0.994526 | \n",
- "
\n",
- " \n",
- " What (8) | \n",
- " 0.198133 | \n",
- " 0.007814 | \n",
- " 0.148431 | \n",
- " 0.018417 | \n",
- " 0.003066 | \n",
- " 0.000065 | \n",
- " 0.003120 | \n",
- " 0.001621 | \n",
- " 0.126054 | \n",
- "
\n",
- " \n",
- " time (9) | \n",
- " 0.196840 | \n",
- " 0.005341 | \n",
- " 0.009895 | \n",
- " 0.077305 | \n",
- " 0.884765 | \n",
- " 0.000052 | \n",
- " 0.000132 | \n",
- " 0.925403 | \n",
- " 0.994526 | \n",
- "
\n",
- " \n",
- " does (10) | \n",
- " 0.306942 | \n",
- " 0.989326 | \n",
- " 0.997537 | \n",
- " 0.999354 | \n",
- " 0.999670 | \n",
- " 0.999856 | \n",
- " 0.999899 | \n",
- " 0.999896 | \n",
- " 0.994526 | \n",
- "
\n",
- " \n",
- " the (11) | \n",
- " 0.079202 | \n",
- " 0.481994 | \n",
- " 0.116743 | \n",
- " 0.195275 | \n",
- " 0.008685 | \n",
- " 0.000106 | \n",
- " 0.001967 | \n",
- " 0.005189 | \n",
- " 0.018881 | \n",
- "
\n",
- " \n",
- " clock (12) | \n",
- " 0.132234 | \n",
- " 0.082222 | \n",
- " 0.065567 | \n",
- " 0.025915 | \n",
- " 0.004124 | \n",
- " 0.000114 | \n",
- " 0.000178 | \n",
- " 0.000570 | \n",
- " 0.028210 | \n",
- "
\n",
- " \n",
- " show (13) | \n",
- " 0.006805 | \n",
- " 0.014062 | \n",
- " 0.980760 | \n",
- " 0.999361 | \n",
- " 0.999670 | \n",
- " 0.999856 | \n",
- " 0.997756 | \n",
- " 0.999865 | \n",
- " 0.994526 | \n",
- "
\n",
- " \n",
- " ? (14) | \n",
- " 0.006805 | \n",
- " 0.014062 | \n",
- " 0.980760 | \n",
- " 0.999361 | \n",
- " 0.999670 | \n",
- " 0.999856 | \n",
- " 0.997756 | \n",
- " 0.999865 | \n",
- " 0.994526 | \n",
+ " . (7) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.701984 | \n",
+ "
\n",
+ " \n",
+ " How (8) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.701984 | \n",
+ "
\n",
+ " \n",
+ " many (9) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.701984 | \n",
+ "
\n",
+ " \n",
+ " minutes (10) | \n",
+ " 0.688956 | \n",
+ " 0.759586 | \n",
+ " 0.701984 | \n",
+ "
\n",
+ " \n",
+ " ' (11) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " til (12) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " 10 (13) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 1.000000 | \n",
+ "
\n",
+ " \n",
+ " ? (14) | \n",
+ " 0.000000 | \n",
+ " 0.000000 | \n",
+ " 1.000000 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
- ""
+ ""
]
},
"metadata": {},
@@ -1945,530 +1456,39 @@
}
],
"source": [
- "logger.print_attribution_matrix(2, attribution_strategy=\"cosine\")\n",
- "logger.print_attribution_matrix(2, attribution_strategy=\"prob_diff\")"
+ "logger.print_attribution_matrix(1, \"cosine\")"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 15,
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Attribution matrix for cosine with perturbation strategy distant:\n",
- "Input Tokens (Rows) vs. Output Tokens (Columns)\n"
+ "ename": "IndexError",
+ "evalue": "index 0 is out of bounds for axis 0 with size 0",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mlogger\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprint_attribution_matrix\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattribution_strategy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcosine\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m logger\u001b[38;5;241m.\u001b[39mprint_attribution_matrix(\u001b[38;5;241m2\u001b[39m, attribution_strategy\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mprob_diff\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
+ "File \u001b[0;32m~/Projects/llm-attribution/attribution/experiment_logger.py:184\u001b[0m, in \u001b[0;36mExperimentLogger.print_attribution_matrix\u001b[0;34m(self, exp_id, attribution_strategy)\u001b[0m\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 176\u001b[0m \u001b[38;5;66;03m# Filter the data for the specific experiment and attribution strategy\u001b[39;00m\n\u001b[1;32m 177\u001b[0m exp_data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdf_token_attribution_matrix[\n\u001b[1;32m 178\u001b[0m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdf_token_attribution_matrix[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexp_id\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m==\u001b[39m exp_id)\n\u001b[1;32m 179\u001b[0m \u001b[38;5;241m&\u001b[39m (\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 182\u001b[0m )\n\u001b[1;32m 183\u001b[0m ]\n\u001b[0;32m--> 184\u001b[0m perturbation_strategy \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdf_experiments\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloc\u001b[49m\u001b[43m[\u001b[49m\n\u001b[1;32m 185\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdf_experiments\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mexp_id\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mexp_id\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mperturbation_strategy\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\n\u001b[1;32m 186\u001b[0m \u001b[43m \u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# Create the pivot table for the matrix\u001b[39;00m\n\u001b[1;32m 189\u001b[0m matrix \u001b[38;5;241m=\u001b[39m exp_data\u001b[38;5;241m.\u001b[39mpivot(\n\u001b[1;32m 190\u001b[0m index\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_token_pos\u001b[39m\u001b[38;5;124m\"\u001b[39m, columns\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moutput_token_pos\u001b[39m\u001b[38;5;124m\"\u001b[39m, values\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattr_score\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 191\u001b[0m )\n",
+ "\u001b[0;31mIndexError\u001b[0m: index 0 is out of bounds for axis 0 with size 0"
]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- " \n",
- " \n",
- " | \n",
- " The (0) | \n",
- " building (1) | \n",
- " is (2) | \n",
- " 132 (3) | \n",
- " meters (4) | \n",
- " tall (5) | \n",
- " . (6) | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " The (0) | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " building (1) | \n",
- " 0.226659 | \n",
- " 0.705013 | \n",
- " 0.520012 | \n",
- " 0.806286 | \n",
- " 0.855425 | \n",
- " 0.756068 | \n",
- " 0.466330 | \n",
- "
\n",
- " \n",
- " is (2) | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " 132 (3) | \n",
- " 0.226659 | \n",
- " 0.770341 | \n",
- " 0.621099 | \n",
- " 0.807823 | \n",
- " 0.785745 | \n",
- " 0.767478 | \n",
- " 0.480238 | \n",
- "
\n",
- " \n",
- " met (4) | \n",
- " 0.226659 | \n",
- " 0.705013 | \n",
- " 0.520012 | \n",
- " 0.788010 | \n",
- " 0.798989 | \n",
- " 0.771913 | \n",
- " 0.412676 | \n",
- "
\n",
- " \n",
- " ers (5) | \n",
- " 0.226659 | \n",
- " 0.705013 | \n",
- " 0.520012 | \n",
- " 0.788010 | \n",
- " 0.798989 | \n",
- " 0.771913 | \n",
- " 0.412676 | \n",
- "
\n",
- " \n",
- " tall (6) | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " . (7) | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " How (8) | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " tall (9) | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " is (10) | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " -0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- " 0.000000 | \n",
- "
\n",
- " \n",
- " the (11) | \n",
- " 0.000000 | \n",
- " 0.758592 | \n",
- " 0.410312 | \n",
- " 0.698263 | \n",
- " 0.818360 | \n",
- " 0.803358 | \n",
- " 0.749395 | \n",
- "
\n",
- " \n",
- " building (12) | \n",
- " 0.226659 | \n",
- " 0.705013 | \n",
- " 0.520012 | \n",
- " 0.788010 | \n",
- " 0.798989 | \n",
- " 0.778579 | \n",
- " 0.796747 | \n",
- "
\n",
- " \n",
- " ? (13) | \n",
- " 0.226659 | \n",
- " 0.705013 | \n",
- " 0.520012 | \n",
- " 0.788010 | \n",
- " 0.798989 | \n",
- " 0.778579 | \n",
- " 0.796747 | \n",
- "
\n",
- " \n",
- "
\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Attribution matrix for prob_diff with perturbation strategy distant:\n",
- "Input Tokens (Rows) vs. Output Tokens (Columns)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- " \n",
- " \n",
- " | \n",
- " The (0) | \n",
- " building (1) | \n",
- " is (2) | \n",
- " (3) | \n",
- " 132 (4) | \n",
- " meters (5) | \n",
- " tall (6) | \n",
- " . (7) | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " The (0) | \n",
- " 0.018736 | \n",
- " 0.000113 | \n",
- " 0.000309 | \n",
- " 0.000079 | \n",
- " 0.000053 | \n",
- " 0.000006 | \n",
- " 0.000121 | \n",
- " 0.000022 | \n",
- "
\n",
- " \n",
- " building (1) | \n",
- " 0.777744 | \n",
- " 0.999023 | \n",
- " 0.999883 | \n",
- " 0.999955 | \n",
- " 0.999962 | \n",
- " 0.999989 | \n",
- " 0.999822 | \n",
- " 0.999864 | \n",
- "
\n",
- " \n",
- " is (2) | \n",
- " 0.015364 | \n",
- " 0.000713 | \n",
- " 0.008316 | \n",
- " 0.000166 | \n",
- " 0.000016 | \n",
- " 0.000007 | \n",
- " 0.000146 | \n",
- " 0.000301 | \n",
- "
\n",
- " \n",
- " 132 (3) | \n",
- " 0.985685 | \n",
- " 0.999023 | \n",
- " 0.999883 | \n",
- " 0.999893 | \n",
- " 0.999962 | \n",
- " 0.999989 | \n",
- " 0.999822 | \n",
- " 0.999864 | \n",
- "
\n",
- " \n",
- " met (4) | \n",
- " 0.963465 | \n",
- " 0.999023 | \n",
- " 0.999883 | \n",
- " 0.999942 | \n",
- " 0.999962 | \n",
- " 0.999989 | \n",
- " 0.999822 | \n",
- " 0.999864 | \n",
- "
\n",
- " \n",
- " ers (5) | \n",
- " 0.963465 | \n",
- " 0.999023 | \n",
- " 0.999883 | \n",
- " 0.999942 | \n",
- " 0.999962 | \n",
- " 0.999989 | \n",
- " 0.999822 | \n",
- " 0.999864 | \n",
- "
\n",
- " \n",
- " tall (6) | \n",
- " 0.057443 | \n",
- " 0.000039 | \n",
- " 0.000001 | \n",
- " 0.001233 | \n",
- " 0.000211 | \n",
- " 0.000013 | \n",
- " 0.001011 | \n",
- " 0.000231 | \n",
- "
\n",
- " \n",
- " . (7) | \n",
- " 0.057443 | \n",
- " 0.000039 | \n",
- " 0.000001 | \n",
- " 0.001233 | \n",
- " 0.000211 | \n",
- " 0.000013 | \n",
- " 0.001011 | \n",
- " 0.000231 | \n",
- "
\n",
- " \n",
- " How (8) | \n",
- " 0.017349 | \n",
- " 0.013034 | \n",
- " 0.006253 | \n",
- " 0.351449 | \n",
- " 0.024528 | \n",
- " 0.000197 | \n",
- " 0.034164 | \n",
- " 0.228418 | \n",
- "
\n",
- " \n",
- " tall (9) | \n",
- " 0.028325 | \n",
- " 0.024104 | \n",
- " 0.002190 | \n",
- " 0.004585 | \n",
- " 0.000051 | \n",
- " 0.000016 | \n",
- " 0.000100 | \n",
- " 0.001761 | \n",
- "
\n",
- " \n",
- " is (10) | \n",
- " 0.303973 | \n",
- " 0.439957 | \n",
- " 0.009363 | \n",
- " 0.001068 | \n",
- " 0.000115 | \n",
- " 0.000057 | \n",
- " 0.000024 | \n",
- " 0.001697 | \n",
- "
\n",
- " \n",
- " the (11) | \n",
- " 0.455051 | \n",
- " 0.987737 | \n",
- " 0.795123 | \n",
- " 0.999935 | \n",
- " 0.999962 | \n",
- " 0.999989 | \n",
- " 0.999822 | \n",
- " 0.989866 | \n",
- "
\n",
- " \n",
- " building (12) | \n",
- " 0.972775 | \n",
- " 0.999023 | \n",
- " 0.999883 | \n",
- " 0.999955 | \n",
- " 0.999962 | \n",
- " 0.999989 | \n",
- " 0.999822 | \n",
- " 0.999864 | \n",
- "
\n",
- " \n",
- " ? (13) | \n",
- " 0.972775 | \n",
- " 0.999023 | \n",
- " 0.999883 | \n",
- " 0.999955 | \n",
- " 0.999962 | \n",
- " 0.999989 | \n",
- " 0.999822 | \n",
- " 0.999864 | \n",
- "
\n",
- " \n",
- "
\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
}
],
"source": [
- "logger.print_attribution_matrix(3, attribution_strategy=\"cosine\")\n",
- "logger.print_attribution_matrix(3, attribution_strategy=\"prob_diff\")"
+ "logger.print_attribution_matrix(2, attribution_strategy=\"cosine\")\n",
+ "logger.print_attribution_matrix(2, attribution_strategy=\"prob_diff\")"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -2493,7 +1513,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.6"
+ "version": "3.9.19"
}
},
"nbformat": 4,