diff --git a/docs/sphinx/source/examples/simplified-retrieval-with-colpali-vlm_Vespa-cloud.ipynb b/docs/sphinx/source/examples/simplified-retrieval-with-colpali-vlm_Vespa-cloud.ipynb index 2bb9fd57..ad185d43 100644 --- a/docs/sphinx/source/examples/simplified-retrieval-with-colpali-vlm_Vespa-cloud.ipynb +++ b/docs/sphinx/source/examples/simplified-retrieval-with-colpali-vlm_Vespa-cloud.ipynb @@ -103,7 +103,8 @@ " process_images,\n", " process_queries,\n", ")\n", - "from colpali_engine.utils.image_utils import scale_image, get_base64_image" + "from colpali_engine.utils.image_utils import scale_image, get_base64_image\n", + "from torch.amp import autocast" ] }, { @@ -133,13 +134,13 @@ "source": [ "if torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", - " type = torch.bfloat16\n", + " dtype = torch.bfloat16\n", "elif torch.backends.mps.is_available():\n", " device = torch.device(\"mps\")\n", - " type = torch.float32\n", + " dtype = torch.float32\n", "else:\n", " device = torch.device(\"cpu\")\n", - " type = torch.float32" + " dtype = torch.float32" ] }, { @@ -590,9 +591,10 @@ " )\n", " for batch_doc in tqdm(dataloader):\n", " with torch.no_grad():\n", - " batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}\n", - " embeddings_doc = model(**batch_doc)\n", - " page_embeddings.extend(list(torch.unbind(embeddings_doc.to(\"cpu\"))))\n", + " with autocast(device_type=device.type):\n", + " batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}\n", + " embeddings_doc = model(**batch_doc)\n", + " page_embeddings.extend(list(torch.unbind(embeddings_doc.to(\"cpu\"))))\n", " pdf[\"embeddings\"] = page_embeddings" ] }, @@ -974,9 +976,10 @@ "qs = []\n", "for batch_query in dataloader:\n", " with torch.no_grad():\n", - " batch_query = {k: v.to(model.device) for k, v in batch_query.items()}\n", - " embeddings_query = model(**batch_query)\n", - " qs.extend(list(torch.unbind(embeddings_query.to(\"cpu\"))))" + " with autocast(device_type=device.type):\n", + " batch_query = {k: v.to(model.device) for k, v in batch_query.items()}\n", + " embeddings_query = model(**batch_query)\n", + " qs.extend(list(torch.unbind(embeddings_query.to(\"cpu\"))))" ] }, {