Skip to content

Commit

Permalink
Use GIST Embeddings (#28)
Browse files (browse the repository at this point in the history)
* Use langchain-community

Signed-off-by: Aivin V. Solatorio <[email protected]>

* Update Qdrant import to use langchain_community

Signed-off-by: Aivin V. Solatorio <[email protected]>

* Use GIST Embedding for docs

Signed-off-by: Aivin V. Solatorio <[email protected]>

* Use GIST embedding for indicators

Signed-off-by: Aivin V. Solatorio <[email protected]>

---------

Signed-off-by: Aivin V. Solatorio <[email protected]>
  • Loading branch information
avsolatorio authored Feb 26, 2024
1 parent 5e000f5 commit a4ab398
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 9 deletions.
8 changes: 6 additions & 2 deletions llm4data/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Base classes for embedding models."""
from typing import Union, Optional
from langchain import embeddings as langchain_embeddings
from langchain_community import embeddings as langchain_embeddings
from pydantic.main import ModelMetaclass
from qdrant_client.http import models
from pydantic.main import ModelMetaclass
Expand All @@ -17,6 +17,10 @@ class BaseEmbeddingModel:
"instruct": 768,
"all-MiniLM-L6-v2": 384,
"multi-qa-mpnet-base-dot-v1": 768,
"avsolatorio/GIST-all-MiniLM-L6-v2": 384,
"avsolatorio/GIST-small-Embedding-v0": 384,
"avsolatorio/GIST-Embedding-v0": 768,
"avsolatorio/GIST-large-Embedding-v0": 384,
}
model_name: str
distance: Union[str, models.Distance]
Expand All @@ -36,7 +40,7 @@ class BaseEmbeddingModel:

@property
def model_id(self):
return f"{self.data_type}_{self.model_name}_{self.collection_name}_{self.distance}_{self.size}_{self.max_tokens}_{self.is_instruct}"
return f"{self.data_type}_{self.model_name.replace('/', '_')}_{self.collection_name}_{self.distance}_{self.size}_{self.max_tokens}_{self.is_instruct}"

def dict(self):
return asdict(self)
Expand Down
2 changes: 1 addition & 1 deletion llm4data/embeddings/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_docs_embeddings():

if DOCS_EMBEDDINGS is None:
DOCS_EMBEDDINGS = DocsEmbedding(
model_name="all-MiniLM-L6-v2",
model_name="avsolatorio/GIST-small-Embedding-v0",
distance="Cosine",
embedding_cls="HuggingFaceEmbeddings",
is_instruct=False,
Expand Down
16 changes: 11 additions & 5 deletions llm4data/embeddings/indicators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@ def get_indicators_embeddings():
global INDICATORS_EMBEDDINGS

if INDICATORS_EMBEDDINGS is None:
# INDICATORS_EMBEDDINGS = IndicatorsEmbedding(
# model_name="instruct",
# distance="Cosine",
# embedding_cls="HuggingFaceInstructEmbeddings",
# is_instruct=True,
# embed_instruction="Represent the Economic Development description for retrieval; Input: ",
# query_instruction="Represent the Economic Development prompt for retrieving descriptions; Input: ",
# )
INDICATORS_EMBEDDINGS = IndicatorsEmbedding(
model_name="instruct",
model_name="avsolatorio/GIST-small-Embedding-v0",
distance="Cosine",
embedding_cls="HuggingFaceInstructEmbeddings",
is_instruct=True,
embed_instruction="Represent the Economic Development description for retrieval; Input: ",
query_instruction="Represent the Economic Development prompt for retrieving descriptions; Input: ",
embedding_cls="HuggingFaceEmbeddings",
is_instruct=False,
)

return INDICATORS_EMBEDDINGS
2 changes: 1 addition & 1 deletion llm4data/index/qdrant.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from typing import Optional, Union
from langchain.vectorstores import Qdrant
from langchain_community.vectorstores import Qdrant
import qdrant_client
from qdrant_client.http import models

Expand Down

0 comments on commit a4ab398

Please sign in to comment.