Skip to content

Commit

Permalink
fix: Fix Pipeline.run() getting stuck in a loop even though there a…
Browse files Browse the repository at this point in the history
…re components that can run (#7434)
  • Loading branch information
silvanocerza authored Mar 28, 2024
1 parent 6fcb62a commit 6e28969
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 5 deletions.
13 changes: 8 additions & 5 deletions haystack/core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,26 +897,29 @@ def run(self, word: str):
and last_waiting_for_input is not None
and before_last_waiting_for_input == last_waiting_for_input
):
# Are we actually stuck or there's a lazy variadic waiting for input?
# This is our last resort, if there's no lazy variadic waiting for input
# Are we actually stuck or there's a lazy variadic or a component with has only default inputs waiting for input?
# This is our last resort, if there's no lazy variadic or component with only default inputs waiting for input
# we're stuck for real and we can't make any progress.
for name, comp in waiting_for_input:
is_variadic = any(socket.is_variadic for socket in comp.__haystack_input__._sockets_dict.values()) # type: ignore
if is_variadic and not comp.__haystack_is_greedy__: # type: ignore[attr-defined]
has_only_defaults = all(
not socket.is_mandatory for socket in comp.__haystack_input__._sockets_dict.values() # type: ignore
)
if is_variadic and not comp.__haystack_is_greedy__ or has_only_defaults: # type: ignore[attr-defined]
break
else:
# We're stuck in a loop for real, we can't make any progress.
# BAIL!
break

if len(waiting_for_input) == 1:
# We have a single component with variadic input waiting for input.
# We have a single component with variadic input or only default inputs waiting for input.
# If we're at this point it means it has been waiting for input for at least 2 iterations.
# This will never run.
# BAIL!
break

# There was a lazy variadic waiting for input, we can run it
# There was a lazy variadic or a component with only default waiting for input, we can run it
waiting_for_input.remove((name, comp))
to_run.append((name, comp))
continue
Expand Down
37 changes: 37 additions & 0 deletions test/core/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from haystack import Document
from haystack.components.builders import PromptBuilder
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.others import Multiplexer
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.core.component import component
Expand Down Expand Up @@ -807,3 +808,39 @@ def test_correct_execution_order_of_components_with_only_defaults(spying_tracer)
"Question: What is the capital of France?"
}
}


def test_pipeline_is_not_stuck_with_components_with_only_defaults():
FakeGenerator = component_class(
"FakeGenerator", input_types={"prompt": str}, output_types={"replies": List[str]}, output={"replies": ["Paris"]}
)
docs = [Document(content="Rome is the capital of Italy"), Document(content="Paris is the capital of France")]
doc_store = InMemoryDocumentStore()
doc_store.write_documents(docs)
template = (
"Given the following information, answer the question.\n"
"Context:\n"
"{% for document in documents %}"
" {{ document.content }}\n"
"{% endfor %}"
"Question: {{ query }}"
)

pipe = Pipeline()

pipe.add_component("retriever", InMemoryBM25Retriever(document_store=doc_store))
pipe.add_component("prompt_builder", PromptBuilder(template=template))
pipe.add_component("generator", FakeGenerator())
pipe.add_component("answer_builder", AnswerBuilder())

pipe.connect("retriever", "prompt_builder.documents")
pipe.connect("prompt_builder.prompt", "generator.prompt")
pipe.connect("generator.replies", "answer_builder.replies")
pipe.connect("retriever.documents", "answer_builder.documents")

query = "What is the capital of France?"
res = pipe.run({"query": query})
assert len(res) == 1
answers = res["answer_builder"]["answers"]
assert len(answers) == 1
assert answers[0].data == "Paris"

0 comments on commit 6e28969

Please sign in to comment.