Skip to content

Commit

Permalink
Simplify component init interface (#819)
Browse files Browse the repository at this point in the history
Removes the need to define `**kwargs` when the user defines
the component `__init__` function. This was needed before to pass
the `produces` and `consumes` arguments to some read/write components without the
user having to explicitly define them. This was resolved by removing
those two arguments from the component signature.

Tested on a pipeline with load/write components.
  • Loading branch information
PhilippeMoussalli authored Jan 29, 2024
1 parent c45abe3 commit d97246d
Show file tree
Hide file tree
Showing 37 changed files with 31 additions and 98 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ from fondant.component import PandasTransformComponent
class ExampleComponent(PandasTransformComponent):
def __init__(self, *, argument1, argument2, **kwargs) -> None:
def __init__(self, *, argument1, argument2) -> None:
"""
Args:
argumentX: An argument passed to the component
Expand Down
1 change: 0 additions & 1 deletion components/caption_images/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def __init__(
model_id: str,
batch_size: int,
max_new_tokens: int,
**kwargs,
):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Device: {self.device}")
Expand Down
2 changes: 0 additions & 2 deletions components/chunk_text/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def __init__(
chunk_strategy: t.Optional[str],
chunk_kwargs: t.Optional[dict],
language_text_splitter: t.Optional[str],
**kwargs,
):
"""
Args:
Expand All @@ -59,7 +58,6 @@ def __init__(
https://python.langchain.com/docs/modules/data_connection/document_transformers/
code_splitter
for more information on supported languages.
kwargs: Unhandled keyword arguments passed in by Fondant.
"""
self.chunk_strategy = chunk_strategy
self.chunk_kwargs = chunk_kwargs
Expand Down
2 changes: 0 additions & 2 deletions components/crop_images/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,11 @@ def __init__(
*,
cropping_threshold: int,
padding: int,
**kwargs,
) -> None:
"""
Args:
cropping_threshold (int): threshold parameter used for detecting borders
padding (int): padding for the image cropping.
kwargs: Unhandled keyword arguments passed in by Fondant.
"""
self.cropping_threshold = cropping_threshold
self.padding = padding
Expand Down
2 changes: 0 additions & 2 deletions components/download_images/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def __init__(
resize_only_if_bigger: bool,
min_image_size: int,
max_aspect_ratio: float,
**kwargs,
):
"""Component that downloads images from a list of URLs and executes filtering and resizing.
Expand All @@ -47,7 +46,6 @@ def __init__(
resize_only_if_bigger: If True, resize only if image is bigger than image_size.
min_image_size: Minimum size of the images.
max_aspect_ratio: Maximum aspect ratio of the images.
kwargs: Unhandled keyword arguments passed in by Fondant
Returns:
Dask dataframe
Expand Down
1 change: 0 additions & 1 deletion components/embed_images/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def __init__(
*,
model_id: str,
batch_size: int,
**kwargs,
):
"""
Args:
Expand Down
1 change: 0 additions & 1 deletion components/embed_text/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def __init__(
model: str,
api_keys: dict,
auth_kwargs: dict,
**kwargs,
):
to_env_vars(api_keys)

Expand Down
3 changes: 0 additions & 3 deletions components/evaluate_ragas/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,14 @@ def __init__(
llm_class_name: str,
llm_kwargs: dict,
produces: t.Dict[str, t.Any],
**kwargs,
) -> None:
"""
Args:
llm_module_name: Module from which the LLM is imported. Defaults to
langchain.chat_models
llm_class_name: Name of the selected llm. Defaults to ChatOpenAI
module: Module from which the LLM is imported. Defaults to langchain.llms
llm_kwargs: Arguments of the selected llm
produces: RAGAS metrics to compute.
kwargs: Unhandled keyword arguments passed in by Fondant.
"""
self.llm = self.extract_llm(
llm_module_name=llm_module_name,
Expand Down
2 changes: 0 additions & 2 deletions components/filter_image_resolution/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@ def __init__(
*,
min_image_dim: int,
max_aspect_ratio: float,
**kwargs,
) -> None:
"""
Args:
min_image_dim: minimum image dimension.
max_aspect_ratio: maximum aspect ratio.
kwargs: Unhandled keyword arguments passed in by Fondant.
"""
self.min_image_dim = min_image_dim
self.max_aspect_ratio = max_aspect_ratio
Expand Down
3 changes: 1 addition & 2 deletions components/filter_language/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,11 @@ def is_language(self, row):
class LanguageFilterComponent(PandasTransformComponent):
"""Component that filter columns based on provided language."""

def __init__(self, *, language, **kwargs):
def __init__(self, *, language):
"""Setup language filter component.
Args:
language: Only keep text passages which are in the provided language.
kwargs: Unhandled keyword arguments passed in by Fondant
"""
self.lang_detector = LanguageIdentification(language)

Expand Down
3 changes: 1 addition & 2 deletions components/filter_text_length/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
class FilterTextLengthComponent(PandasTransformComponent):
"""A component that filters out text based on their length."""

def __init__(self, *, min_characters_length: int, min_words_length: int, **kwargs):
def __init__(self, *, min_characters_length: int, min_words_length: int):
"""Setup component.
Args:
min_characters_length: minimum number of characters
min_words_length: minimum number of words.
kwargs: Unhandled keyword arguments passed in by Fondant
"""
self.min_characters_length = min_characters_length
self.min_words_length = min_words_length
Expand Down
3 changes: 1 addition & 2 deletions components/generate_minhash/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,11 @@ def compute_minhash(shingles: list) -> np.ndarray:
class MinHashGeneratorComponent(PandasTransformComponent):
"""Component generates minhashes of text."""

def __init__(self, *, shingle_ngram_size: int, **kwargs):
def __init__(self, *, shingle_ngram_size: int):
"""Setup component.
Args:
shingle_ngram_size: Defines size of ngram used for the shingle generation.
kwargs: Unhandled keyword arguments passed in by Fondant
"""
self.shingle_ngram_size = shingle_ngram_size

Expand Down
2 changes: 0 additions & 2 deletions components/index_aws_opensearch/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def __init__(
use_ssl: Optional[bool],
verify_certs: Optional[bool],
pool_maxsize: Optional[int],
**kwargs,
):
session = boto3.Session()
credentials = session.get_credentials()
Expand All @@ -35,7 +34,6 @@ def __init__(
verify_certs=verify_certs,
connection_class=RequestsHttpConnection,
pool_maxsize=pool_maxsize,
**kwargs,
)
self.create_index(index_body)

Expand Down
1 change: 0 additions & 1 deletion components/index_qdrant/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def __init__(
host: Optional[str] = None,
path: Optional[str] = None,
force_disable_check_same_thread: bool = False,
**kwargs,
):
"""Initialize the IndexQdrantComponent with the component parameters."""
self.client = QdrantClient(
Expand Down
1 change: 0 additions & 1 deletion components/index_weaviate/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(
additional_headers: t.Optional[dict],
vectorizer: t.Optional[str],
module_config: t.Optional[dict],
**kwargs,
):
self.client = weaviate.Client(
url=weaviate_url,
Expand Down
5 changes: 0 additions & 5 deletions components/load_from_csv/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import dask.dataframe as dd
import pandas as pd
from fondant.component import DaskLoadComponent
from fondant.core.schema import Field

logger = logging.getLogger(__name__)

Expand All @@ -13,13 +12,11 @@ class CSVReader(DaskLoadComponent):
def __init__(
self,
*,
produces: t.Dict[str, Field],
dataset_uri: str,
column_separator: str,
column_name_mapping: t.Optional[dict],
n_rows_to_load: t.Optional[int],
index_column: t.Optional[str],
**kwargs,
) -> None:
"""
Args:
Expand All @@ -33,14 +30,12 @@ def __init__(
index_column: Column to set index to in the load component,
if not specified a default globally unique index will
be set.
kwargs: Unhandled keyword arguments passed in by Fondant.
"""
self.dataset_uri = dataset_uri
self.column_separator = column_separator
self.column_name_mapping = column_name_mapping
self.n_rows_to_load = n_rows_to_load
self.index_column = index_column
self.produces = produces

def get_columns_to_keep(self) -> t.List[str]:
# Only read required columns
Expand Down
2 changes: 1 addition & 1 deletion components/load_from_files/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def get_filesystem(path_uri: str) -> fsspec.spec.AbstractFileSystem | None:
class LoadFromFiles(DaskLoadComponent):
"""Component that loads datasets from files."""

def __init__(self, *_, directory_uri: str, **kwargs) -> None:
def __init__(self, *_, directory_uri: str) -> None:
self.directory_uri = directory_uri

def load(self) -> dd.DataFrame:
Expand Down
6 changes: 0 additions & 6 deletions components/load_from_hf_hub/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import dask.dataframe as dd
import pandas as pd
from fondant.component import DaskLoadComponent
from fondant.core.schema import Field

logger = logging.getLogger(__name__)

Expand All @@ -17,17 +16,14 @@ class LoadFromHubComponent(DaskLoadComponent):
def __init__(
self,
*,
produces: t.Dict[str, Field],
dataset_name: str,
column_name_mapping: t.Optional[dict],
image_column_names: t.Optional[list],
n_rows_to_load: t.Optional[int],
index_column: t.Optional[str],
**kwargs,
) -> None:
"""
Args:
produces: The schema the component should produce
dataset_name: name of the dataset to load.
column_name_mapping: Mapping of the consumed hub dataset to fondant column names
image_column_names: A list containing the original hub image column names. Used to
Expand All @@ -36,14 +32,12 @@ def __init__(
testing pipeline runs on a small scale.
index_column: Column to set index to in the load component, if not specified a default
globally unique index will be set.
kwargs: Unhandled keyword arguments passed in by Fondant.
"""
self.dataset_name = dataset_name
self.column_name_mapping = column_name_mapping
self.image_column_names = image_column_names
self.n_rows_to_load = n_rows_to_load
self.index_column = index_column
self.produces = produces

def get_columns_to_keep(self) -> t.List[str]:
# Only read required columns
Expand Down
5 changes: 0 additions & 5 deletions components/load_from_parquet/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import dask.dataframe as dd
import pandas as pd
from fondant.component import DaskLoadComponent
from fondant.core.schema import Field

logger = logging.getLogger(__name__)

Expand All @@ -17,12 +16,10 @@ class LoadFromParquet(DaskLoadComponent):
def __init__(
self,
*,
produces: t.Dict[str, Field],
dataset_uri: str,
column_name_mapping: t.Optional[dict],
n_rows_to_load: t.Optional[int],
index_column: t.Optional[str],
**kwargs,
) -> None:
"""
Args:
Expand All @@ -33,13 +30,11 @@ def __init__(
testing pipeline runs on a small scale.
index_column: Column to set index to in the load component, if not specified a default
globally unique index will be set.
kwargs: Unhandled keyword arguments passed in by Fondant.
"""
self.dataset_uri = dataset_uri
self.column_name_mapping = column_name_mapping
self.n_rows_to_load = n_rows_to_load
self.index_column = index_column
self.produces = produces

def get_columns_to_keep(self) -> t.List[str]:
# Only read required columns
Expand Down
6 changes: 0 additions & 6 deletions components/load_from_pdf/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,21 @@
import fsspec as fs
import pandas as pd
from fondant.component import DaskLoadComponent
from fondant.core.schema import Field

logger = logging.getLogger(__name__)


class PDFReader(DaskLoadComponent):
def __init__(
self,
produces: t.Dict[str, Field],
*,
pdf_path: str,
n_rows_to_load: t.Optional[int] = None,
index_column: t.Optional[str] = None,
n_partitions: t.Optional[int] = None,
**kwargs,
) -> None:
"""
Args:
produces: The schema the component should produce
pdf_path: Path to the PDF file
n_rows_to_load: optional argument that defines the number of rows to load.
Useful for testing pipeline runs on a small scale.
Expand All @@ -34,9 +30,7 @@ def __init__(
n_partitions: Number of partitions of the dask dataframe. If not specified, the number
of partitions will be equal to the number of CPU cores. Set to high values if
the data is large and the pipeline is running out of memory.
kwargs: Unhandled keyword arguments passed in by Fondant.
"""
self.produces = produces
self.pdf_path = pdf_path
self.n_rows_to_load = n_rows_to_load
self.index_column = index_column
Expand Down
1 change: 0 additions & 1 deletion components/normalize_text/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def __init__(
normalize_lines: bool,
do_lowercase: bool,
remove_punctuation: bool,
**kwargs,
):
self.remove_additional_whitespaces = remove_additional_whitespaces
self.apply_nfc = apply_nfc
Expand Down
2 changes: 1 addition & 1 deletion components/resize_images/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class ResizeImagesComponent(PandasTransformComponent):
"""Component that resizes images based on a given width and height."""

def __init__(self, *, resize_width: int, resize_height: int, **kwargs) -> None:
def __init__(self, *, resize_width: int, resize_height: int) -> None:
self.resize_width = resize_width
self.resize_height = resize_height

Expand Down
2 changes: 0 additions & 2 deletions components/retrieve_from_weaviate/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@ def __init__(
hybrid_query: t.Optional[str],
hybrid_alpha: t.Optional[float],
rerank: bool,
**kwargs,
) -> None:
"""
Args:
weaviate_url: An argument passed to the component.
class_name: Name of class to query
top_k: Amount of context to return.
kwargs: Unhandled keyword arguments passed in by Fondant.
additional_config: Additional configuration passed to the weaviate client.
additional_headers: Additional headers passed to the weaviate client.
hybrid_query: The hybrid query to be used for retrieval. Optional parameter.
Expand Down
2 changes: 0 additions & 2 deletions components/retrieve_laion_by_embedding/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def __init__(
num_images: int,
aesthetic_score: int,
aesthetic_weight: float,
**kwargs,
) -> None:
"""
Expand All @@ -30,7 +29,6 @@ def __init__(
between 0 and 9.
aesthetic_weight: weight of the aesthetic embedding to add to the query,
between 0 and 1.
kwargs: Unhandled keyword arguments passed in by Fondant.
"""
self.client = ClipClient(
url="https://knn.laion.ai/knn-service",
Expand Down
Loading

0 comments on commit d97246d

Please sign in to comment.