Skip to content

Commit

Permalink
TCT PRF notebook #29
Browse files Browse the repository at this point in the history
  • Loading branch information
cmacdonald authored Nov 26, 2024
1 parent 081db58 commit be6a3cf
Showing 1 changed file with 396 additions and 0 deletions.
396 changes: 396 additions & 0 deletions examples/tct-prf.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,396 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "7e2d65ff-eea0-44bd-a736-9675ae28ad9c",
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade git+https://github.com/terrierteam/pyterrier_dr.git\n",
"import pyterrier as pt\n",
"pt.utils.set_tqdm('notebook')\n",
"import pyterrier_dr"
]
},
{
"cell_type": "markdown",
"id": "bd1ae81d-21fe-49d1-b282-a06bc6a53ce7",
"metadata": {},
"source": [
"Load your index and the relevant model\n",
"\n",
"If you dont have an existing TCT index, one can be created for MSMARCO as follows:\n",
"```python\n",
"index = pyterrier_dr.FlexIndex('myindex.flex')\n",
"idx_pipeline = model >> index\n",
"idx_pipeline.index(pt.get_dataset('irds:msmarco-passage').get_corpus_iter())\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9dc40df5-9898-4888-a2c1-72c27f3d7bb7",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8d57b49b48444edcad1957c10465e3db",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0%| | 0.00/334 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eb0c8d86392843089c0f28fafd0574b9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/559 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4772adacafa14c6194569fb48358c0a7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"vocab.txt: 0%| | 0.00/232k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "30fc505f140d482baceff1e46ba0630e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"special_tokens_map.json: 0%| | 0.00/112 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "481d4491e9bb42d381244c3130254df1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"pytorch_model.bin: 0%| | 0.00/438M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/miniconda3/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
" return self.fget.__get__(instance, owner)()\n"
]
}
],
"source": [
"index_loc = '/nfs/global_indices/msmarco-passage.tct-hnp.flex'\n",
"model = pyterrier_dr.TctColBert('castorini/tct_colbert-v2-hnp-msmarco')\n",
"index = pyterrier_dr.FlexIndex(index_loc)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "46556e84-b895-47b4-b439-8016cb4081cd",
"metadata": {},
"outputs": [],
"source": [
"# standard dense retrieval\n",
"baseline = model >> index"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "75d84842-93c8-455a-9c96-e1d74f166298",
"metadata": {},
"outputs": [],
"source": [
"# PRF pipelines\n",
"vector_prf = model >> index % 3 >> index.vec_loader() >> pyterrier_dr.VectorPrf() >> index\n",
"average_prf = model >> index % 3 >> index.vec_loader() >> pyterrier_dr.AveragePrf() >> index"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "389f1b82-9a81-4e76-8d43-8638ee24bd9f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Java started (triggered by _read_topics_singleline) and loaded: pyterrier.java, pyterrier.terrier.java [version=5.10 (build: craigm 2024-08-22 17:33), helper_version=0.0.8]\n",
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n",
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2e183ef4c5a3493dba2ed54d3600ba3a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n",
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a6551a6965c54aafbc59c45de6609334",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n",
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8a828590e0da4fba96609bb82cd48446",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n",
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c5c6c1658fc04fe2bc48e158ca6eb150",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/miniconda3/lib/python3.10/site-packages/pyterrier_dr/flex/core.py:90: UserWarning: performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\n",
" warn(\"performing exhaustive search with FlexIndex.np_retriever -- note that other FlexIndex retrievers may be faster\")\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "35514c524f7747ee8d01504ca6f25adc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"NumpyRetriever scoring: 0%| | 0/2159 [00:00<?, ?docbatch/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>nDCG@10</th>\n",
" <th>AP(rel=2)@100</th>\n",
" <th>nDCG@10 +</th>\n",
" <th>nDCG@10 -</th>\n",
" <th>nDCG@10 p-value</th>\n",
" <th>AP(rel=2)@100 +</th>\n",
" <th>AP(rel=2)@100 -</th>\n",
" <th>AP(rel=2)@100 p-value</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>TCT</td>\n",
" <td>0.720577</td>\n",
" <td>0.402403</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>TCT &gt;&gt; VectorPRF</td>\n",
" <td>0.732201</td>\n",
" <td>0.425013</td>\n",
" <td>24.0</td>\n",
" <td>11.0</td>\n",
" <td>0.203288</td>\n",
" <td>30.0</td>\n",
" <td>12.0</td>\n",
" <td>0.000902</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>TCT &gt;&gt; AveragePRF</td>\n",
" <td>0.729361</td>\n",
" <td>0.441232</td>\n",
" <td>25.0</td>\n",
" <td>14.0</td>\n",
" <td>0.505680</td>\n",
" <td>28.0</td>\n",
" <td>14.0</td>\n",
" <td>0.032259</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name nDCG@10 AP(rel=2)@100 nDCG@10 + nDCG@10 - \\\n",
"0 TCT 0.720577 0.402403 NaN NaN \n",
"1 TCT >> VectorPRF 0.732201 0.425013 24.0 11.0 \n",
"2 TCT >> AveragePRF 0.729361 0.441232 25.0 14.0 \n",
"\n",
" nDCG@10 p-value AP(rel=2)@100 + AP(rel=2)@100 - AP(rel=2)@100 p-value \n",
"0 NaN NaN NaN NaN \n",
"1 0.203288 30.0 12.0 0.000902 \n",
"2 0.505680 28.0 14.0 0.032259 "
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"import pyterrier as pt\n",
"from pyterrier.measures import *\n",
"pt.Experiment(\n",
" [baseline, vector_prf, average_prf],\n",
" pt.get_dataset('msmarco_passage').get_topics('test-2019'),\n",
" pt.get_dataset('msmarco_passage').get_qrels('test-2019'),\n",
" [nDCG@10, AP(rel=2)@100],\n",
" names=[\"TCT\", \"TCT >> VectorPRF\", \"TCT >> AveragePRF\"],\n",
" baseline=0\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ecfd01d-7365-4f04-a3f1-563db0c5f920",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit be6a3cf

Please sign in to comment.