
Commit

Merge pull request #979 from vespa-engine/thomasht86/fix-colpali-notebook

Thomasht86/update colpali notebooks
thomasht86 authored Nov 22, 2024
2 parents b38751d + 9d1cb6d commit 6bf4808
Showing 3 changed files with 52 additions and 42 deletions.
Changed file 1 of 3
@@ -59,7 +59,8 @@
"from colpali_engine.utils.colpali_processing_utils import (\n",
" process_images,\n",
" process_queries,\n",
")"
")\n",
"from torch.amp import autocast"
]
},
{
@@ -88,13 +89,13 @@
"source": [
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
" type = torch.bfloat16\n",
" dtype = torch.bfloat16\n",
"elif torch.backends.mps.is_available():\n",
" device = torch.device(\"mps\")\n",
" type = torch.float32\n",
" dtype = torch.float32\n",
"else:\n",
" device = torch.device(\"cpu\")\n",
" type = torch.float32"
" dtype = torch.float32"
]
},
{
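The hunks above and below rename the notebook's `type` variable to `dtype`, so it no longer shadows Python's built-in `type()`. A minimal sketch of how the pieces fit together after the change (ColPali is imported from colpali_engine earlier in the notebook; variable names are the notebook's own):

import torch

# bfloat16 on CUDA, float32 on MPS/CPU; "dtype" no longer shadows the builtin
if torch.cuda.is_available():
    device = torch.device("cuda")
    dtype = torch.bfloat16
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    dtype = torch.float32
else:
    device = torch.device("cpu")
    dtype = torch.float32

# the same variable is handed straight to the model loader (see the next hunk)
model = ColPali.from_pretrained(
    "vidore/colpaligemma-3b-pt-448-base", torch_dtype=dtype
).eval()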
@@ -286,7 +287,7 @@
"source": [
"model_name = \"vidore/colpali-v1.2\"\n",
"model = ColPali.from_pretrained(\n",
" \"vidore/colpaligemma-3b-pt-448-base\", torch_dtype=type\n",
" \"vidore/colpaligemma-3b-pt-448-base\", torch_dtype=dtype\n",
").eval()\n",
"model.load_adapter(model_name)\n",
"model = model.eval()\n",
@@ -356,9 +357,10 @@
"embeddings = []\n",
"for batch_doc in tqdm(dataloader):\n",
" with torch.no_grad():\n",
" batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}\n",
" embeddings_doc = model(**batch_doc)\n",
" embeddings.extend(list(torch.unbind(embeddings_doc.to(\"cpu\"))))"
" with autocast(device_type=device.type):\n",
" batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}\n",
" embeddings_doc = model(**batch_doc)\n",
" embeddings.extend(list(torch.unbind(embeddings_doc.to(\"cpu\"))))"
]
},
{
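The other half of the change wraps the forward pass in torch.amp.autocast, so mixed-precision inference is used where the device supports it. A minimal sketch of how the document-embedding loop reads after this change (dataloader, model, and device are defined elsewhere in the notebook):

from torch.amp import autocast

embeddings = []
for batch_doc in tqdm(dataloader):
    with torch.no_grad():  # inference only, no gradients needed
        with autocast(device_type=device.type):  # mixed precision where available
            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            embeddings.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

The query-embedding loop in the next hunk gets the same wrapping.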
@@ -392,9 +394,10 @@
"query_embeddings = []\n",
"for batch_query in tqdm(dataloader):\n",
" with torch.no_grad():\n",
" batch_query = {k: v.to(model.device) for k, v in batch_query.items()}\n",
" embeddings_query = model(**batch_query)\n",
" query_embeddings.extend(list(torch.unbind(embeddings_query.to(\"cpu\"))))"
" with autocast(device_type=device.type):\n",
" batch_query = {k: v.to(model.device) for k, v in batch_query.items()}\n",
" embeddings_query = model(**batch_query)\n",
" query_embeddings.extend(list(torch.unbind(embeddings_query.to(\"cpu\"))))"
]
},
{
@@ -6238,4 +6241,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
Changed file 2 of 3
@@ -22,8 +22,8 @@
"\n",
"[ColPali: Efficient Document Retrieval with Vision Language Models Manuel Faysse, Hugues Sibille, Tony Wu, Bilel Omrani, Gautier Viaud, Céline Hudelot, Pierre Colombo](https://arxiv.org/abs/2407.01449v2)\n",
"\n",
"ColPail is a combination of [ColBERT](https://blog.vespa.ai/announcing-colbert-embedder-in-vespa/) \n",
"and [PailGemma](https://huggingface.co/blog/paligemma):\n",
"ColPali is a combination of [ColBERT](https://blog.vespa.ai/announcing-colbert-embedder-in-vespa/) \n",
"and [PaliGemma](https://huggingface.co/blog/paligemma):\n",
"\n",
">ColPali is enabled by the latest advances in Vision Language Models, notably the PaliGemma model from the Google Zürich team, and leverages multi-vector retrieval through late interaction mechanisms as proposed in ColBERT by Omar Khattab.\n",
"\n",
@@ -61,7 +61,7 @@
"\n",
"Let us get started. \n",
"\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vespa-engine/pyvespa/blob/master/docs/sphinx/source/examples/colpali-document-retrieval-vision-language-models.ipynb)\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vespa-engine/pyvespa/blob/master/docs/sphinx/source/examples/colpali-document-retrieval-vision-language-models-cloud.ipynb)\n",
"\n",
"\n",
"Install dependencies: \n",
@@ -100,7 +100,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"id": "qKFOvdo5nCVl"
},
@@ -118,7 +118,8 @@
" process_images,\n",
" process_queries,\n",
")\n",
"from colpali_engine.utils.image_utils import scale_image, get_base64_image"
"from colpali_engine.utils.image_utils import scale_image, get_base64_image\n",
"from torch.amp import autocast"
]
},
{
@@ -142,19 +143,19 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
" type = torch.bfloat16\n",
" dtype = torch.bfloat16\n",
"elif torch.backends.mps.is_available():\n",
" device = torch.device(\"mps\")\n",
" type = torch.float32\n",
" dtype = torch.float32\n",
"else:\n",
" device = torch.device(\"cpu\")\n",
" type = torch.float32"
" dtype = torch.float32"
]
},
{
@@ -339,7 +340,7 @@
"source": [
"model_name = \"vidore/colpali-v1.2\"\n",
"model = ColPali.from_pretrained(\n",
" \"vidore/colpaligemma-3b-pt-448-base\", torch_dtype=type\n",
" \"vidore/colpaligemma-3b-pt-448-base\", torch_dtype=dtype\n",
").eval()\n",
"model.load_adapter(model_name)\n",
"model = model.eval()\n",
@@ -497,7 +498,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -526,9 +527,10 @@
" )\n",
" for batch_doc in tqdm(dataloader):\n",
" with torch.no_grad():\n",
" batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}\n",
" embeddings_doc = model(**batch_doc)\n",
" page_embeddings.extend(list(torch.unbind(embeddings_doc.to(\"cpu\"))))\n",
" with autocast(device_type=device.type):\n",
" batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}\n",
" embeddings_doc = model(**batch_doc)\n",
" page_embeddings.extend(list(torch.unbind(embeddings_doc.cpu())))\n",
" pdf[\"embeddings\"] = page_embeddings"
]
},
@@ -927,7 +929,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": null,
"metadata": {
"id": "NxeDd3mcYDpL"
},
@@ -942,9 +944,10 @@
"qs = []\n",
"for batch_query in dataloader:\n",
" with torch.no_grad():\n",
" batch_query = {k: v.to(model.device) for k, v in batch_query.items()}\n",
" embeddings_query = model(**batch_query)\n",
" qs.extend(list(torch.unbind(embeddings_query.to(\"cpu\"))))"
" with autocast(device_type=device.type):\n",
" batch_query = {k: v.to(model.device) for k, v in batch_query.items()}\n",
" embeddings_query = model(**batch_query)\n",
" qs.extend(list(torch.unbind(embeddings_query.cpu())))"
]
},
{
Changed file 3 of 3
@@ -103,7 +103,8 @@
" process_images,\n",
" process_queries,\n",
")\n",
"from colpali_engine.utils.image_utils import scale_image, get_base64_image"
"from colpali_engine.utils.image_utils import scale_image, get_base64_image\n",
"from torch.amp import autocast"
]
},
{
@@ -133,13 +134,13 @@
"source": [
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
" type = torch.bfloat16\n",
" dtype = torch.bfloat16\n",
"elif torch.backends.mps.is_available():\n",
" device = torch.device(\"mps\")\n",
" type = torch.float32\n",
" dtype = torch.float32\n",
"else:\n",
" device = torch.device(\"cpu\")\n",
" type = torch.float32"
" dtype = torch.float32"
]
},
{
@@ -331,7 +332,7 @@
"source": [
"model_name = \"vidore/colpali-v1.2\"\n",
"model = ColPali.from_pretrained(\n",
" \"vidore/colpaligemma-3b-pt-448-base\", torch_dtype=type\n",
" \"vidore/colpaligemma-3b-pt-448-base\", torch_dtype=dtype\n",
").eval()\n",
"model.load_adapter(model_name)\n",
"model = model.eval()\n",
@@ -432,14 +433,15 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {
"id": "YaDInfmT3Tbu"
},
"outputs": [],
"source": [
"for pdf in sample_pdfs:\n",
" page_images, page_texts = get_pdf_images(pdf[\"url\"])\n",
"\n",
" pdf[\"images\"] = page_images\n",
" pdf[\"texts\"] = page_texts"
]
@@ -590,9 +592,10 @@
" )\n",
" for batch_doc in tqdm(dataloader):\n",
" with torch.no_grad():\n",
" batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}\n",
" embeddings_doc = model(**batch_doc)\n",
" page_embeddings.extend(list(torch.unbind(embeddings_doc.to(\"cpu\"))))\n",
" with autocast(device_type=device.type):\n",
" batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}\n",
" embeddings_doc = model(**batch_doc)\n",
" page_embeddings.extend(list(torch.unbind(embeddings_doc.to(\"cpu\"))))\n",
" pdf[\"embeddings\"] = page_embeddings"
]
},
@@ -974,9 +977,10 @@
"qs = []\n",
"for batch_query in dataloader:\n",
" with torch.no_grad():\n",
" batch_query = {k: v.to(model.device) for k, v in batch_query.items()}\n",
" embeddings_query = model(**batch_query)\n",
" qs.extend(list(torch.unbind(embeddings_query.to(\"cpu\"))))"
" with autocast(device_type=device.type):\n",
" batch_query = {k: v.to(model.device) for k, v in batch_query.items()}\n",
" embeddings_query = model(**batch_query)\n",
" qs.extend(list(torch.unbind(embeddings_query.to(\"cpu\"))))"
]
},
{
