Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: fixing pylint issues #8610

Merged
merged 7 commits on Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class TransformersZeroShotDocumentClassifier:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
model: str,
labels: List[str],
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/embedders/azure_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class AzureOpenAIDocumentEmbedder:
```
"""

def __init__( # noqa: PLR0913 (too-many-arguments)
def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-positional-arguments
self,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/embedders/azure_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class AzureOpenAITextEmbedder:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/embedders/openai_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class OpenAITextEmbedder:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "text-embedding-ada-002",
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/evaluators/context_relevance.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class ContextRelevanceEvaluator(LLMEvaluator):
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
examples: Optional[List[Dict[str, Any]]] = None,
progress_bar: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/evaluators/faithfulness.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class FaithfulnessEvaluator(LLMEvaluator):
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
examples: Optional[List[Dict[str, Any]]] = None,
progress_bar: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/evaluators/llm_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class LLMEvaluator:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
instructions: str,
inputs: List[Tuple[str, Type[List]]],
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/generators/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
"""

# pylint: disable=super-init-not-called
def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/generators/chat/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
"""

# pylint: disable=super-init-not-called
def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
Expand Down
4 changes: 2 additions & 2 deletions haystack/components/generators/chat/hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class HuggingFaceLocalChatGenerator:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
model: str = "HuggingFaceH4/zephyr-7b-beta",
task: Optional[Literal["text-generation", "text2text-generation"]] = None,
Expand Down Expand Up @@ -295,7 +295,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
]
return {"replies": chat_messages}

def create_message(
def create_message( # pylint: disable=too-many-positional-arguments
self,
text: str,
index: int,
Expand Down
5 changes: 3 additions & 2 deletions haystack/components/rankers/meta_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class MetaFieldRanker:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
meta_field: str,
weight: float = 1.0,
Expand Down Expand Up @@ -106,6 +106,7 @@ def __init__(

def _validate_params(
self,
*,
weight: float,
top_k: Optional[int],
ranking_mode: Literal["reciprocal_rank_fusion", "linear_score"],
Expand Down Expand Up @@ -156,7 +157,7 @@ def _validate_params(
)

@component.output_types(documents=List[Document])
def run(
def run( # pylint: disable=too-many-positional-arguments
self,
documents: List[Document],
top_k: Optional[int] = None,
Expand Down
34 changes: 24 additions & 10 deletions haystack/components/readers/extractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class ExtractiveReader:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
model: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
device: Optional[ComponentDevice] = None,
Expand Down Expand Up @@ -192,8 +192,9 @@ def warm_up(self):
)
self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(self.model.hf_device_map))

@staticmethod
def _flatten_documents(
self, queries: List[str], documents: List[List[Document]]
queries: List[str], documents: List[List[Document]]
) -> Tuple[List[str], List[Document], List[int]]:
"""
Flattens queries and Documents so all query-document pairs are arranged along one batch axis.
Expand All @@ -203,8 +204,8 @@ def _flatten_documents(
query_ids = [i for i, documents_ in enumerate(documents) for _ in documents_]
return flattened_queries, flattened_documents, query_ids

def _preprocess(
self, queries: List[str], documents: List[Document], max_seq_length: int, query_ids: List[int], stride: int
def _preprocess( # pylint: disable=too-many-positional-arguments
davidsbatista marked this conversation as resolved.
Show resolved Hide resolved
self, *, queries: List[str], documents: List[Document], max_seq_length: int, query_ids: List[int], stride: int
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", List["Encoding"], List[int], List[int]]:
"""
Splits and tokenizes Documents and preserves structures by returning mappings to query and Document IDs.
Expand Down Expand Up @@ -256,6 +257,7 @@ def _preprocess(

def _postprocess(
self,
*,
start: "torch.Tensor",
end: "torch.Tensor",
sequence_ids: "torch.Tensor",
Expand Down Expand Up @@ -285,9 +287,9 @@ def _postprocess(
masked_logits = torch.where(mask, logits, -torch.inf)
probabilities = torch.sigmoid(masked_logits * self.calibration_factor)

flat_probabilities = probabilities.flatten(-2, -1) # necessary for topk
flat_probabilities = probabilities.flatten(-2, -1) # necessary for top-k

# topk can return invalid candidates as well if answers_per_seq > num_valid_candidates
# top-k can return invalid candidates as well if answers_per_seq > num_valid_candidates
# We only keep probability > 0 candidates later on
candidates = torch.topk(flat_probabilities, answers_per_seq)
seq_length = logits.shape[-1]
Expand Down Expand Up @@ -343,6 +345,7 @@ def _add_answer_page_number(self, answer: ExtractedAnswer) -> ExtractedAnswer:

def _nest_answers(
self,
*,
start: List[List[int]],
end: List[List[int]],
probabilities: "torch.Tensor",
Expand Down Expand Up @@ -526,7 +529,7 @@ def deduplicate_by_overlap(
return deduplicated_answers

@component.output_types(answers=List[ExtractedAnswer])
def run(
def run( # pylint: disable=too-many-positional-arguments
self,
query: str,
documents: List[Document],
Expand Down Expand Up @@ -594,9 +597,15 @@ def run(
no_answer = no_answer if no_answer is not None else self.no_answer
overlap_threshold = overlap_threshold or self.overlap_threshold

flattened_queries, flattened_documents, query_ids = self._flatten_documents(queries, nested_documents)
flattened_queries, flattened_documents, query_ids = ExtractiveReader._flatten_documents(
queries, nested_documents
)
input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess(
flattened_queries, flattened_documents, max_seq_length, query_ids, stride
queries=flattened_queries,
documents=flattened_documents,
max_seq_length=max_seq_length,
query_ids=query_ids,
stride=stride,
)

num_batches = math.ceil(input_ids.shape[0] / max_batch_size) if max_batch_size else 1
Expand Down Expand Up @@ -625,7 +634,12 @@ def run(
end_logits = torch.cat(end_logits_list)

start, end, probabilities = self._postprocess(
start_logits, end_logits, sequence_ids, attention_mask, answers_per_seq, encodings
start=start_logits,
end=end_logits,
sequence_ids=sequence_ids,
attention_mask=attention_mask,
answers_per_seq=answers_per_seq,
encodings=encodings,
)

answers = self._nest_answers(
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/routers/transformers_text_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class TransformersTextRouter:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
model: str,
labels: Optional[List[str]] = None,
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/routers/zero_shot_text_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class TransformersZeroShotTextRouter:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
labels: List[str],
multi_label: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions haystack/document_stores/in_memory/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class InMemoryDocumentStore:
Stores data in-memory. It's ephemeral and cannot be saved to disk.
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
bm25_tokenization_regex: str = r"(?u)\b\w\w+\b",
bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L",
Expand Down Expand Up @@ -541,7 +541,7 @@ def bm25_retrieval(

return return_documents

def embedding_retrieval(
def embedding_retrieval( # pylint: disable=too-many-positional-arguments
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
Expand Down
2 changes: 1 addition & 1 deletion haystack/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def patch_make_records_to_use_kwarg_string_interpolation(original_make_records:
"""A decorator to ensure string interpolation is used."""

@functools.wraps(original_make_records)
def _wrapper(name, level, fn, lno, msg, args, exc_info, func=None, extra=None, sinfo=None) -> Any:
def _wrapper(name, level, fn, lno, msg, args, exc_info, func=None, extra=None, sinfo=None) -> Any: # pylint: disable=too-many-positional-arguments
safe_extra = extra or {}
try:
interpolated_msg = msg.format(**safe_extra)
Expand Down
2 changes: 1 addition & 1 deletion haystack/testing/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def to_dict(self) -> Dict[str, Any]:
return cls


def component_class(
def component_class( # pylint: disable=too-many-positional-arguments
name: str,
input_types: Optional[Dict[str, Any]] = None,
output_types: Optional[Dict[str, Any]] = None,
Expand Down
2 changes: 1 addition & 1 deletion haystack/utils/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def resolve_hf_device_map(device: Optional[ComponentDevice], model_kwargs: Optio
return model_kwargs


def resolve_hf_pipeline_kwargs(
def resolve_hf_pipeline_kwargs( # pylint: disable=too-many-positional-arguments
huggingface_pipeline_kwargs: Dict[str, Any],
model: str,
task: Optional[str],
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ max-locals = 45 # Default is 15
max-module-lines = 2468 # Default is 1000
max-nested-blocks = 9 # Default is 5
max-statements = 206 # Default is 50

[tool.pylint.'SIMILARITIES']
min-similarity-lines = 6

Expand Down
15 changes: 12 additions & 3 deletions test/components/readers/test_extractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def test_flatten_documents(mock_reader: ExtractiveReader):

def test_preprocess(mock_reader: ExtractiveReader):
_, _, seq_ids, _, query_ids, doc_ids = mock_reader._preprocess(
example_queries * 3, example_documents[0], 384, [1, 1, 1], 0
queries=example_queries * 3, documents=example_documents[0], max_seq_length=384, query_ids=[1, 1, 1], stride=0
)
expected_seq_ids = torch.full((3, 384), -1, dtype=torch.int)
expected_seq_ids[:, :16] = 0
Expand All @@ -333,7 +333,11 @@ def test_preprocess(mock_reader: ExtractiveReader):

def test_preprocess_splitting(mock_reader: ExtractiveReader):
_, _, seq_ids, _, query_ids, doc_ids = mock_reader._preprocess(
example_queries * 4, example_documents[0] + [Document(content="a" * 64)], 96, [1, 1, 1, 1], 0
queries=example_queries * 4,
documents=example_documents[0] + [Document(content="a" * 64)],
max_seq_length=96,
query_ids=[1, 1, 1, 1],
stride=0,
)
assert seq_ids.shape[0] == 5
assert query_ids == [1, 1, 1, 1, 1]
Expand Down Expand Up @@ -362,7 +366,12 @@ def test_postprocess(mock_reader: ExtractiveReader):
encoding.token_to_chars = lambda i: (int(i), int(i) + 1)

start_candidates, end_candidates, probs = mock_reader._postprocess(
start, end, sequence_ids, attention_mask, 3, [encoding, encoding]
start=start,
end=end,
sequence_ids=sequence_ids,
attention_mask=attention_mask,
answers_per_seq=3,
encodings=[encoding, encoding],
)

assert len(start_candidates) == len(end_candidates) == len(probs) == 2
Expand Down
Loading