Skip to content

Commit

Permalink
update to match
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasht86 committed Nov 15, 2024
1 parent 6ab0238 commit f5ac4d1
Showing 1 changed file with 54 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1265,14 +1265,13 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": null,
"id": "5b7acbb6",
"metadata": {
"id": "5b7acbb6"
},
"outputs": [],
"source": [
"# Define the Vespa schema\n",
"colpali_schema = Schema(\n",
" name=VESPA_SCHEMA_NAME,\n",
" document=Document(\n",
Expand All @@ -1293,9 +1292,7 @@
" index=\"enable-bm25\",\n",
" ),\n",
" Field(name=\"page_number\", type=\"int\", indexing=[\"summary\", \"attribute\"]),\n",
" Field(\n",
" name=\"blur_image\", type=\"raw\", indexing=[\"summary\"]\n",
" ), # We store the images as base64-encoded strings\n",
" Field(name=\"blur_image\", type=\"raw\", indexing=[\"summary\"]),\n",
" Field(name=\"full_image\", type=\"raw\", indexing=[\"summary\"]),\n",
" Field(\n",
" name=\"text\",\n",
Expand All @@ -1305,31 +1302,29 @@
" index=\"enable-bm25\",\n",
" ),\n",
" Field(\n",
" name=\"embedding\", # The page image embeddings are stored as binary tensors (represented by signed 8-bit integers)\n",
" name=\"embedding\",\n",
" type=\"tensor<int8>(patch{}, v[16])\",\n",
" indexing=[\n",
" \"attribute\",\n",
" \"index\",\n",
" ],\n",
" ann=HNSW( # Since we are using binary embeddings, we use the HNSW algorithm with Hamming distance as the metric, which is highly efficient in Vespa, see https://blog.vespa.ai/scaling-colpali-to-billions/\n",
" ann=HNSW(\n",
" distance_metric=\"hamming\",\n",
" max_links_per_node=32,\n",
" neighbors_to_explore_at_insert=400,\n",
" ),\n",
" ),\n",
" Field(\n",
" name=\"questions\", # We store the generated questions and queries as arrays of strings for each document\n",
" name=\"questions\",\n",
" type=\"array<string>\",\n",
" indexing=[\"summary\", \"index\", \"attribute\"],\n",
" index=\"enable-bm25\",\n",
" stemming=\"best\",\n",
" indexing=[\"summary\", \"attribute\"],\n",
" summary=Summary(fields=[\"matched-elements-only\"]),\n",
" ),\n",
" Field(\n",
" name=\"queries\",\n",
" type=\"array<string>\",\n",
" indexing=[\"summary\", \"index\", \"attribute\"],\n",
" index=\"enable-bm25\",\n",
" stemming=\"best\",\n",
" indexing=[\"summary\", \"attribute\"],\n",
" summary=Summary(fields=[\"matched-elements-only\"]),\n",
" ),\n",
" ]\n",
" ),\n",
Expand Down Expand Up @@ -1389,7 +1384,7 @@
"]\n",
"\n",
"# Define the 'bm25' rank profile\n",
"colpali_bm25_profile = RankProfile(\n",
"bm25 = RankProfile(\n",
" name=\"bm25\",\n",
" inputs=[(\"query(qt)\", \"tensor<float>(querytoken{}, v[128])\")],\n",
" first_phase=\"bm25(title) + bm25(text)\",\n",
Expand All @@ -1407,14 +1402,27 @@
" )\n",
"\n",
"\n",
"colpali_schema.add_rank_profile(colpali_bm25_profile)\n",
"colpali_schema.add_rank_profile(with_quantized_similarity(colpali_bm25_profile))\n",
"colpali_schema.add_rank_profile(bm25)\n",
"colpali_schema.add_rank_profile(with_quantized_similarity(bm25))\n",
"\n",
"# Update the 'default' rank profile\n",
"colpali_profile = RankProfile(\n",
" name=\"default\",\n",
" inputs=[(\"query(qt)\", \"tensor<float>(querytoken{}, v[128])\")],\n",
" first_phase=\"bm25_score\",\n",
"\n",
"# Update the 'colpali' rank profile\n",
"input_query_tensors = []\n",
"MAX_QUERY_TERMS = 64\n",
"for i in range(MAX_QUERY_TERMS):\n",
" input_query_tensors.append((f\"query(rq{i})\", \"tensor<int8>(v[16])\"))\n",
"\n",
"input_query_tensors.extend(\n",
" [\n",
" (\"query(qt)\", \"tensor<float>(querytoken{}, v[128])\"),\n",
" (\"query(qtb)\", \"tensor<int8>(querytoken{}, v[16])\"),\n",
" ]\n",
")\n",
"\n",
"colpali = RankProfile(\n",
" name=\"colpali\",\n",
" inputs=input_query_tensors,\n",
" first_phase=\"max_sim_binary\",\n",
" second_phase=SecondPhaseRanking(expression=\"max_sim\", rerank_count=10),\n",
" functions=mapfunctions\n",
" + [\n",
Expand All @@ -1432,30 +1440,33 @@
" )\n",
" \"\"\",\n",
" ),\n",
" Function(name=\"bm25_score\", expression=\"bm25(title) + bm25(text)\"),\n",
" Function(\n",
" name=\"max_sim_binary\",\n",
" expression=\"\"\"\n",
" sum(\n",
" reduce(\n",
" 1 / (1 + sum(\n",
" hamming(query(qtb), attribute(embedding)), v)\n",
" ),\n",
" max, patch\n",
" ),\n",
" querytoken\n",
" )\n",
" \"\"\",\n",
" ),\n",
" ],\n",
")\n",
"colpali_schema.add_rank_profile(colpali_profile)\n",
"colpali_schema.add_rank_profile(with_quantized_similarity(colpali_profile))\n",
"\n",
"# Update the 'retrieval-and-rerank' rank profile\n",
"input_query_tensors = []\n",
"MAX_QUERY_TERMS = 64\n",
"for i in range(MAX_QUERY_TERMS):\n",
" input_query_tensors.append((f\"query(rq{i})\", \"tensor<int8>(v[16])\"))\n",
"\n",
"input_query_tensors.extend(\n",
" [\n",
" (\"query(qt)\", \"tensor<float>(querytoken{}, v[128])\"),\n",
" (\"query(qtb)\", \"tensor<int8>(querytoken{}, v[16])\"),\n",
" ]\n",
")\n",
"colpali_schema.add_rank_profile(colpali)\n",
"colpali_schema.add_rank_profile(with_quantized_similarity(colpali))\n",
"\n",
"colpali_retrieval_profile = RankProfile(\n",
" name=\"retrieval-and-rerank\",\n",
"# Update the 'hybrid' rank profile\n",
"hybrid = RankProfile(\n",
" name=\"hybrid\",\n",
" inputs=input_query_tensors,\n",
" first_phase=\"max_sim_binary\",\n",
" second_phase=SecondPhaseRanking(expression=\"max_sim\", rerank_count=10),\n",
" second_phase=SecondPhaseRanking(\n",
" expression=\"max_sim + 2 * (bm25(text) + bm25(title))\", rerank_count=10\n",
" ),\n",
" functions=mapfunctions\n",
" + [\n",
" Function(\n",
Expand Down Expand Up @@ -1488,8 +1499,8 @@
" ),\n",
" ],\n",
")\n",
"colpali_schema.add_rank_profile(colpali_retrieval_profile)\n",
"colpali_schema.add_rank_profile(with_quantized_similarity(colpali_retrieval_profile))"
"colpali_schema.add_rank_profile(hybrid)\n",
"colpali_schema.add_rank_profile(with_quantized_similarity(hybrid))"
]
},
{
Expand Down

0 comments on commit f5ac4d1

Please sign in to comment.