FEATURE: Enhance embedding functionality with batch and image support. (#55)

* added litellm component

* support chat history

* trimmed comments

* dynamically get the model parameters

* added llm load balancer

* added the AI PR reviewer workflow

* fixed minor issues

* controlled session id

* refactored chat history and reused codes

* fixed minor logging issue

* reverted minor changes

* handle all LiteLLM inferences and embedding requests by the load balancer

* updated documents

* fix: remove useless import command

* refactor: restructure the LLM components

* Added support for batch embedding + image embedding

* Added init py

* typo

* fixed embedding

---------

Co-authored-by: alimosaed <[email protected]>
cyrus2281 and alimosaed authored Nov 4, 2024
1 parent 948554f commit 3b4c99a
Showing 8 changed files with 41 additions and 20 deletions.
9 changes: 6 additions & 3 deletions docs/components/langchain_embeddings.md
@@ -24,14 +24,17 @@ component_config:
 ```
 {
-  text: <string>,
+  items: [
+    ,
+    ...
+  ],
   type: <string>
 }
 ```
 | Field | Required | Description |
 | --- | --- | --- |
-| text | True | The text to embed |
-| type | False | The type of embedding to use: 'document' or 'query' - default is 'document' |
+| items | True | A single element or a list of elements to embed |
+| type | False | The type of embedding to use: 'document', 'query', or 'image' - default is 'document' |


 ## Component Output Schema
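As a rough illustration of the updated input schema, the sketch below builds payloads in the two shapes the `items` field now accepts (a single element or a list). The helper `make_embedding_input` and the `VALID_TYPES` tuple are hypothetical names introduced here, not part of the connector, and the example strings are made up.

```python
# Hypothetical helper, not part of the connector: builds an input payload
# that follows the documented schema (items required, type optional).
VALID_TYPES = ("document", "query", "image")


def make_embedding_input(items, embedding_type="document"):
    """Return a dict shaped like the component input described above."""
    if embedding_type not in VALID_TYPES:
        raise ValueError(f"unsupported embedding type: {embedding_type!r}")
    return {"items": items, "type": embedding_type}


# A single element is allowed...
single = make_embedding_input("What is an event mesh?", "query")

# ...and so is a list of elements (batch embedding).
batch = make_embedding_input(["first document", "second document"])
```

The commit does not state what form image items take (paths, URIs, or raw data), so no image payload is shown here.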
4 changes: 2 additions & 2 deletions examples/llm/litellm_embedding.yaml
@@ -62,9 +62,9 @@ flows:
           payload_format: json

       #
-      # Do an LLM request
+      # Do an Embedding request
       #
-      - component_name: llm_request
+      - component_name: embedding_request
         component_module: litellm_embeddings
         component_config:
           load_balancer:
Empty file.
Empty file.
@@ -33,16 +33,16 @@
     "input_schema": {
         "type": "object",
         "properties": {
-            "text": {
-                "type": "string",
-                "description": "The text to embed",
+            "items": {
+                "type": "array",
+                "description": "A single element or a list of elements to embed",
             },
             "type": {
-                "type": "string",  # This is document or query
-                "description": "The type of embedding to use: 'document' or 'query' - default is 'document'",
+                "type": "string",  # This is document, query, or image
+                "description": "The type of embedding to use: 'document', 'query', or 'image' - default is 'document'",
             },
         },
-        "required": ["text"],
+        "required": ["items"],
     },
     "output_schema": {
         "type": "object",
@@ -66,13 +66,28 @@ def __init__(self, **kwargs):
         super().__init__(info, **kwargs)

     def invoke(self, message, data):
-        text = data["text"]
+        items = data["items"]
         embedding_type = data.get("type", "document")

-        embeddings = None
+        items = [items] if type(items) != list else items
+
         if embedding_type == "document":
-            embeddings = self.component.embed_documents([text])
+            return self.embed_documents(items)
         elif embedding_type == "query":
-            embeddings = [self.component.embed_query(text)]
+            return self.embed_queries(items)
+        elif embedding_type == "image":
+            return self.embed_images(items)
+
+    def embed_documents(self, documents):
+        embeddings = self.component.embed_documents(documents)
+        return {"embeddings": embeddings}
+
+    def embed_queries(self, queries):
+        embeddings = []
+        for query in queries:
+            embeddings.append(self.component.embed_query(query))
+        return {"embeddings": embeddings}

-        return {"embedding": embeddings[0]}
+    def embed_images(self, images):
+        embeddings = self.component.embed_images(images)
+        return {"embeddings": embeddings}
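The dispatch introduced in this hunk can be exercised in isolation with a small sketch. Everything below is illustrative: `FakeEmbedder` is a made-up stand-in for the wrapped embeddings object (its `embed_images` method is an assumption, mirroring the call the commit makes), and `embed` is a standalone rewrite rather than the component's code. It deliberately differs from the commit in two ways: it uses `isinstance` for the list check, and it raises on an unknown `type`, whereas `invoke` as committed falls through and returns `None`.

```python
class FakeEmbedder:
    """Made-up stand-in for the wrapped embeddings object (assumption)."""

    def embed_documents(self, texts):
        # One fake 1-dimensional vector per text: its length.
        return [[float(len(t))] for t in texts]

    def embed_query(self, text):
        return [float(len(text))]

    def embed_images(self, images):
        # Assumed to take a list, as the commit's call suggests.
        return [[0.0] for _ in images]


def embed(component, data):
    """Standalone sketch of the batch-aware dispatch (not the component itself)."""
    items = data["items"]
    embedding_type = data.get("type", "document")

    # Accept either a single element or a list of elements.
    items = items if isinstance(items, list) else [items]

    if embedding_type == "document":
        return {"embeddings": component.embed_documents(items)}
    if embedding_type == "query":
        return {"embeddings": [component.embed_query(q) for q in items]}
    if embedding_type == "image":
        return {"embeddings": component.embed_images(items)}
    raise ValueError(f"unsupported embedding type: {embedding_type!r}")


single = embed(FakeEmbedder(), {"items": "abc"})
batch = embed(FakeEmbedder(), {"items": ["a", "bb"], "type": "query"})
```

Note the output change visible in the diff: the old code returned a single vector under the key `embedding`, while the new code always returns a list of vectors under `embeddings`, even for one input item.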
@@ -125,4 +125,4 @@ def load_balance(self, messages, stream):

     def invoke(self, message, data):
         """invoke the model"""
-        pass
+        pass
@@ -8,4 +8,4 @@

 class LiteLLMChatModel(LiteLLMChatModelBase):
     def __init__(self, **kwargs):
-        super().__init__(info, **kwargs)
+        super().__init__(info, **kwargs)
@@ -47,5 +47,8 @@ def invoke(self, message, data):
             input=items)

         # Extract the embedding data from the response
-        embedding_data = response['data'][0]['embedding']
-        return {"embeddings": embedding_data}
+        embeddings = []
+        for embedding in response.get("data", []):
+            embeddings.append(embedding['embedding'])
+
+        return {"embeddings": embeddings}
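To make the effect of this last hunk concrete, here is a minimal sketch that contrasts the old first-entry extraction with the new loop over all entries. It assumes the response behaves like a dict with a `data` list of `{"embedding": [...]}` entries, which is what the diff's `response.get("data", [])` and `embedding['embedding']` accesses imply; the sample response and the function names are invented for illustration.

```python
def extract_first(response):
    """Old behaviour: only the first entry's vector is returned."""
    return response["data"][0]["embedding"]


def extract_all(response):
    """New behaviour: one vector per entry, empty list if 'data' is missing."""
    return [entry["embedding"] for entry in response.get("data", [])]


# Invented sample shaped like the accesses in the diff (assumption).
sample_response = {
    "data": [
        {"index": 0, "embedding": [0.1, 0.2]},
        {"index": 1, "embedding": [0.3, 0.4]},
    ]
}

first_only = extract_first(sample_response)
all_vectors = extract_all(sample_response)
```

With a batch of two inputs, the old extraction would silently drop the second vector; the new one keeps both, and a response without `data` yields an empty list instead of raising `KeyError`.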
