Skip to content

Commit

Permalink
Make embeddings optional in weaviate component (#791)
Browse files Browse the repository at this point in the history
Fixes [Make embedding
optional#69](ml6team/fondant-usecase-RAG#69)
  • Loading branch information
PhilippeMoussalli authored Jan 18, 2024
1 parent 08834b5 commit 538aa63
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 34 deletions.
97 changes: 93 additions & 4 deletions components/index_weaviate/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,96 @@

<a id="index_weaviate#description"></a>
## Description
Component that takes embeddings of text snippets and indexes them into a weaviate vector database.
Component that takes text or embeddings of text snippets and indexes them into a Weaviate vector database.

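To run the component with text snippets as input, the component needs to be connected to a previous component that outputs text snippets, and both `vectorizer` and a non-empty `module_config` must be set so that Weaviate computes the embeddings itself. If no `vectorizer` is set, the input dataset must contain an `embedding` column.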

### Running with text as input

```python
import pyarrow as pa
from fondant.pipeline import Pipeline

pipeline = Pipeline(name="my_pipeline", base_path="path/to/pipeline")

dataset = pipeline.read(
"load_from_csv",
arguments={
"dataset_uri": "path/to/dataset.csv",
},
produces={
"text": pa.string(),
}
)

dataset = dataset.apply(
"index_weaviate",
arguments={
"weaviate_url": "http://localhost:8080",
"class_name": "my_class",
"vectorizer": "text2vec-openai",
"module_config": {
"text2vec-openai": {
"model": "ada",
"modelVersion": "002",
"type": "text"
}
},
"additional_headers" : {
"X-OpenAI-Api-Key": "YOUR-OPENAI-API-KEY"
}
},
consumes={
"text": "text"
}
)
```

### Running with embedding as input

```python
import pyarrow as pa
from fondant.pipeline import Pipeline

pipeline = Pipeline(name="my_pipeline",base_path="path/to/pipeline")

dataset = pipeline.read(
"load_from_csv",
arguments={
"dataset_uri": "path/to/dataset.csv",
},
produces={
"text": pa.string(),
}
)

dataset = dataset.apply(
"embed_text",
arguments={...},
consumes={
"text": "text",
},
)

dataset = dataset.apply(
"index_weaviate",
arguments={
"weaviate_url": "http://localhost:8080",
"class_name": "my_class",
},
consumes={
"embedding": "embedding"
}
)
```


<a id="index_weaviate#inputs_outputs"></a>
## Inputs / outputs

<a id="index_weaviate#consumes"></a>
### Consumes
**This component consumes:**

- text: string
- embedding: list<item: float>
**This component can consume additional fields**
- <field_name>: <dataset_field_name>
This defines a mapping to update the fields consumed by the operation as defined in the component spec.
The keys are the names of the fields to be received by the component, while the values are
the names of the fields to map from the input dataset.

See the usage example below on how to define a field name for additional fields.



Expand All @@ -36,7 +115,10 @@ The component takes the following arguments to alter its behavior:
| num_workers | int | The maximal number of concurrent threads to run batch import. Parameter of weaviate.batch.Batch().configure(). | 2 |
| overwrite | bool | Whether to overwrite/ re-create the existing weaviate class and its embeddings. | / |
| class_name | str | The name of the weaviate class that will be created and used to store the embeddings. Should follow the weaviate naming conventions. | / |
| additional_config | dict | Additional configuration to pass to the weaviate client. | / |
| additional_headers | dict | Additional headers to pass to the weaviate client. | / |
| vectorizer | str | Which vectorizer to use. You can find the available vectorizers in the weaviate documentation: https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules Set this to None if you want to insert your own embeddings. | / |
| module_config | dict | The configuration of the vectorizer module. You can find the available configuration options in the weaviate documentation: https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules Set this to None if you want to insert your own embeddings. | / |

<a id="index_weaviate#usage"></a>
## Usage
Expand All @@ -63,8 +145,15 @@ dataset.write(
# "num_workers": 2,
# "overwrite": False,
# "class_name": ,
# "additional_config": {},
# "additional_headers": {},
# "vectorizer": ,
# "module_config": {},
},
consumes={
<field_name>: <dataset_field_name>,
..., # Add fields
},
)
```

Expand Down
107 changes: 98 additions & 9 deletions components/index_weaviate/fondant_component.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,88 @@
name: Index Weaviate
description: Component that takes embeddings of text snippets and indexes them into a weaviate vector database.
description: |
Component that takes text or embeddings of text snippets and indexes them into a Weaviate vector database.
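To run the component with text snippets as input, the component needs to be connected to a previous component that outputs text snippets, and both `vectorizer` and a non-empty `module_config` must be set so that Weaviate computes the embeddings itself. If no `vectorizer` is set, the input dataset must contain an `embedding` column.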
### Running with text as input
```python
import pyarrow as pa
from fondant.pipeline import Pipeline
pipeline = Pipeline(name="my_pipeline", base_path="path/to/pipeline")
dataset = pipeline.read(
"load_from_csv",
arguments={
"dataset_uri": "path/to/dataset.csv",
},
produces={
"text": pa.string(),
}
)
dataset = dataset.apply(
"index_weaviate",
arguments={
"weaviate_url": "http://localhost:8080",
"class_name": "my_class",
"vectorizer": "text2vec-openai",
"module_config": {
"text2vec-openai": {
"model": "ada",
"modelVersion": "002",
"type": "text"
}
},
"additional_headers" : {
"X-OpenAI-Api-Key": "YOUR-OPENAI-API-KEY"
}
},
consumes={
"text": "text"
}
)
```
### Running with embedding as input
```python
import pyarrow as pa
from fondant.pipeline import Pipeline
pipeline = Pipeline(name="my_pipeline",base_path="path/to/pipeline")
dataset = pipeline.read(
"load_from_csv",
arguments={
"dataset_uri": "path/to/dataset.csv",
},
produces={
"text": pa.string(),
}
)
dataset = dataset.apply(
"embed_text",
arguments={...},
consumes={
"text": "text",
},
)
dataset = dataset.apply(
"index_weaviate",
arguments={
"weaviate_url": "http://localhost:8080",
"class_name": "my_class",
},
consumes={
"embedding": "embedding"
}
)
```
image: fndnt/index_weaviate:dev
tags:
- Data writing

consumes:
text:
type: string
embedding:
type: array
items:
type: float32
additionalProperties: true


args:
weaviate_url:
Expand Down Expand Up @@ -44,12 +116,29 @@ args:
The name of the weaviate class that will be created and used to store the embeddings.
Should follow the weaviate naming conventions.
type: str
additional_config:
description: |
Additional configuration to pass to the weaviate client.
type: dict
default: None
additional_headers:
description: |
Additional headers to pass to the weaviate client.
type: dict
default: {}
vectorizer:
description: |
Which vectorizer to use.
You can find the available vectorizers in the weaviate documentation: https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules
Set this to None if you want to insert your own embeddings.
type: str
default: None

default: {}
module_config:
description: |
The configuration of the vectorizer module.
You can find the available configuration options in the weaviate documentation: https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules
Set this to None if you want to insert your own embeddings.
type: dict
default: {}
84 changes: 63 additions & 21 deletions components/index_weaviate/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,17 @@ def __init__(
num_workers: int,
overwrite: bool,
class_name: str,
additional_config: t.Optional[dict],
additional_headers: t.Optional[dict],
vectorizer: t.Optional[str],
module_config: t.Optional[dict],
**kwargs,
):
self.client = weaviate.Client(weaviate_url)
self.client = weaviate.Client(
url=weaviate_url,
additional_config=additional_config if additional_config else None,
additional_headers=additional_headers if additional_headers else None,
)

self.client.batch.configure(
batch_size=batch_size,
Expand All @@ -31,46 +38,81 @@ def __init__(
)

self.class_name = class_name
self.vectorizer = vectorizer
self.module_config = module_config

if overwrite:
self.client.schema.delete_class(self.class_name)

if not self.client.schema.exists(self.class_name):
self.client.schema.create_class(
class_schema = self.create_class_schema()
self.client.schema.create_class(class_schema)

def validate_component(self, dataframe: dd.DataFrame) -> None:
    """Check that the configuration and the input columns fit together.

    Args:
        dataframe: the dataframe about to be written; only its ``columns``
            are inspected here.

    Raises:
        ValueError: if no vectorizer is configured and there is no
            'embedding' column, if a vectorizer is configured without a
            (non-empty) module_config, or if the 'text' column is missing.
    """
    columns = dataframe.columns

    # Without a vectorizer the vectors have to be supplied by the dataset.
    if self.vectorizer is None and "embedding" not in columns:
        msg = "If vectorizer is not specified, dataframe must contain an 'embedding' column."
        raise ValueError(msg)

    # An empty dict counts as "not specified", hence the truthiness test.
    if self.vectorizer is not None and not self.module_config:
        msg = "If vectorizer is specified, module_config must be specified as well."
        raise ValueError(msg)

    # The write step reads `row.text` for every row to fill the "passage"
    # property, so fail here with a clear message instead of an
    # AttributeError in the middle of the batch import. Checked last so the
    # pre-existing error cases keep their original messages.
    if "text" not in columns:
        msg = "Dataframe must contain a 'text' column."
        raise ValueError(msg)

def create_class_schema(self) -> t.Dict[str, t.Any]:
class_schema: t.Dict[str, t.Any] = {
"class": self.class_name,
"properties": [
{
"class": class_name,
"properties": [
{
"name": "passage",
"dataType": ["text"],
},
{ # id of the passage in the passage dataset
# not to mix up with weaviate's uuid
"name": "id_",
"dataType": ["text"],
},
],
"vectorizer": vectorizer,
"name": "passage",
"dataType": ["text"],
},
)
{ # id of the passage in the passage dataset
# not to mix up with weaviate's uuid
"name": "id_",
"dataType": ["text"],
},
],
}

if self.vectorizer is not None:
class_schema["vectorizer"] = self.vectorizer

if self.module_config is not None:
class_schema["moduleConfig"] = self.module_config

return class_schema

def teardown(self) -> None:
    """Drop this component's reference to the Weaviate client."""
    del self.client

def write(self, dataframe: dd.DataFrame) -> None:
self.validate_component(dataframe)

with self.client.batch as batch:
for part in tqdm(
dataframe.partitions,
desc="Processing partitions",
total=dataframe.npartitions,
):
df = part.compute()

for row in df.itertuples():
properties = {
"id_": str(row.Index),
"passage": row.text,
}
batch.add_data_object(
data_object=properties,
class_name=self.class_name,
vector=row.embedding,
)

if self.vectorizer is None:
batch.add_data_object(
data_object=properties,
class_name=self.class_name,
vector=row.embedding,
)
else:
batch.add_data_object(
data_object=properties,
class_name=self.class_name,
)
3 changes: 3 additions & 0 deletions components/index_weaviate/tests/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def test_index_weaviate_component(monkeypatch):
overwrite=True,
class_name="TestClass",
vectorizer=None,
additional_headers=None,
additional_config=None,
module_config=None,
)

index_component.write(dask_df)
Expand Down

0 comments on commit 538aa63

Please sign in to comment.