Skip to content

Commit

Permalink
Add optional_variables in ConditionalRouter
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Nov 22, 2024
1 parent 5cc2df1 commit d0c1d14
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
52 changes: 51 additions & 1 deletion haystack/components/routers/conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __init__(
custom_filters: Optional[Dict[str, Callable]] = None,
unsafe: bool = False,
validate_output_type: bool = False,
optional_variables: Optional[List[str]] = None,
):
"""
Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing.
Expand All @@ -136,11 +137,49 @@ def __init__(
:param validate_output_type:
Enable validation of routes' output.
If a route output doesn't match the declared type a ValueError is raised running.
:param optional_variables:
A list of variable names that are optional in your route conditions and outputs.
If these variables are not provided at runtime, they will be set to `None`.
This allows you to write routes that can handle missing inputs gracefully without raising errors.
Example usage with a default fallback route:
```python
routes = [
{
"condition": '{{ path == "rag" }}',
"output": "{{ question }}",
"output_name": "rag_route",
"output_type": str
},
{
"condition": "{{ True }}", # fallback route
"output": "{{ question }}",
"output_name": "default_route",
"output_type": str
}
]
router = ConditionalRouter(routes, optional_variables=["path"])
# When 'path' is provided, specific route is taken:
result = router.run(question="What?", path="rag")
assert result == {"rag_route": "What?"}
# When 'path' is not provided, fallback route is taken:
result = router.run(question="What?")
assert result == {"default_route": "What?"}
```
This pattern is particularly useful when:
- You want to provide default/fallback behavior when certain inputs are missing
- Some variables are only needed for specific routing conditions
- You're building flexible pipelines where not all inputs are guaranteed to be present
"""
self.routes: List[dict] = routes
self.custom_filters = custom_filters or {}
self._unsafe = unsafe
self._validate_output_type = validate_output_type
self.optional_variables = optional_variables or []

# Create a Jinja environment to inspect variables in the condition templates
if self._unsafe:
Expand All @@ -166,7 +205,17 @@ def __init__(
# extract outputs
output_types.update({route["output_name"]: route["output_type"]})

component.set_input_types(self, **{var: Any for var in input_types})
# remove optional variables from mandatory input types
mandatory_input_types = input_types - set(self.optional_variables)

# add mandatory input types
component.set_input_types(self, **{var: Any for var in mandatory_input_types})

# now add optional input types
for optional_var_name in self.optional_variables:
component.set_input_type(self, name=optional_var_name, type=Any, default=None)

# set output types
component.set_output_types(self, **output_types)

def to_dict(self) -> Dict[str, Any]:
Expand All @@ -186,6 +235,7 @@ def to_dict(self) -> Dict[str, Any]:
custom_filters=se_filters,
unsafe=self._unsafe,
validate_output_type=self._validate_output_type,
optional_variables=self.optional_variables,
)

@classmethod
Expand Down
48 changes: 48 additions & 0 deletions test/components/routers/test_conditional_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,51 @@ def test_validate_output_type_with_unsafe(self):
streams = ["1", "2", "3", "4"]
with pytest.raises(ValueError, match="Route 'streams' type doesn't match expected type"):
router.run(streams=streams, message=message)

def test_router_with_optional_parameters(self):
"""
Test that the router works with optional parameters, particularly testing the default/fallback route
when an expected parameter is not provided.
"""
routes = [
{"condition": '{{path == "rag"}}', "output": "{{question}}", "output_name": "normal", "output_type": str},
{
"condition": '{{path == "followup_short"}}',
"output": "{{question}}",
"output_name": "followup_short",
"output_type": str,
},
{
"condition": '{{path == "followup_elaborate"}}',
"output": "{{question}}",
"output_name": "followup_elaborate",
"output_type": str,
},
{"condition": "{{ True }}", "output": "{{ question }}", "output_name": "fallback", "output_type": str},
]

router = ConditionalRouter(routes, optional_variables=["path"])

# Test direct component usage
result = router.run(question="What?")
assert result == {"fallback": "What?"}, "Default route should be taken when 'path' is not provided"

# Test with path parameter
result = router.run(question="What?", path="rag")
assert result == {"normal": "What?"}, "Specific route should be taken when 'path' is provided"

# Test in pipeline
from haystack import Pipeline

pipe = Pipeline()
pipe.add_component("router", router)

# Test pipeline without path parameter
result = pipe.run(data={"router": {"question": "What?"}})
assert result["router"] == {
"fallback": "What?"
}, "Default route should work in pipeline when 'path' is not provided"

# Test pipeline with path parameter
result = pipe.run(data={"router": {"question": "What?", "path": "followup_short"}})
assert result["router"] == {"followup_short": "What?"}, "Specific route should work in pipeline"

0 comments on commit d0c1d14

Please sign in to comment.