diff --git a/docs/docs/reference/interpretations.md b/docs/docs/reference/interpretations.md index 6bb0efc52..8875e7af4 100644 --- a/docs/docs/reference/interpretations.md +++ b/docs/docs/reference/interpretations.md @@ -50,5 +50,5 @@ The switch interpretation allows you to define multiple interpretations that wil | Parameter Name | Required? | Type | Description | |---------------- |----------- |------------ |------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | switch_on | Yes | ValueProvider | The value provider that will be evaluated for each source node. The value of the value provider will be used to determine which interpretation to apply. | -| interpretations | Yes | Dictionary | Contains the interpretations that will be applied. The keys represent the values of the `switch_on` parameter. The values represent the interpretations that will be applied. | +| interpretations | Yes | Dictionary | Contains the interpretations that will be applied. The keys represent the values of the `switch_on` parameter. The values represent the interpretations that will be applied. Each value may also be a list of interpretations. | | default | No | Dictionary | Contains the default interpretation that will be applied if no interpretation has the same value as the value of the `switch_on` parameter. | diff --git a/nodestream/interpreting/interpretations/switch_interpretation.py b/nodestream/interpreting/interpretations/switch_interpretation.py index 6781d276e..e311ba953 100644 --- a/nodestream/interpreting/interpretations/switch_interpretation.py +++ b/nodestream/interpreting/interpretations/switch_interpretation.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, List from ...pipeline.value_providers import ( ProviderContext, @@ -29,6 +29,16 @@ class SwitchInterpretation( "fail_on_unhandled", ) + @staticmethod + def guarantee_interpretation_list_from_file_data(file_data) -> List[Interpretation]: + if isinstance(file_data, list): + return [ + Interpretation.from_file_data(**interpretation) + for interpretation in file_data + ] + + return [Interpretation.from_file_data(**file_data)] + def __init__( self, switch_on: StaticValueOrValueProvider, @@ -39,28 +49,34 @@ def __init__( ): self.switch_on = ValueProvider.guarantee_value_provider(switch_on) self.interpretations = { - field_value: Interpretation.from_file_data(**interpretation) + field_value: self.guarantee_interpretation_list_from_file_data( + interpretation + ) for field_value, interpretation in cases.items() } - self.default = Interpretation.from_file_data(**default) if default else None + self.default = ( + self.guarantee_interpretation_list_from_file_data(default) + if default + else None + ) self.normalization = normalization or {} self.fail_on_unhandled = fail_on_unhandled def all_subordinate_components(self): - yield from self.interpretations.values() + for child in self.interpretations.values(): + yield from child if self.default: - yield self.default + yield from self.default def interpret(self, context: ProviderContext): - value_to_look_for = self.switch_on.normalize_single_value( - context, **self.normalization - ) - if value_to_look_for not in self.interpretations: - if self.default: - self.default.interpret(context) - return + key = self.switch_on.normalize_single_value(context, **self.normalization) + interpretations = self.interpretations.get(key, self.default) + + if interpretations is None: if self.fail_on_unhandled: - raise UnhandledBranchError(value_to_look_for) + raise UnhandledBranchError(key) else: - return - return self.interpretations[value_to_look_for].interpret(context) + interpretations = [] + + for interpretation in interpretations: + interpretation.interpret(context) diff --git a/tests/unit/interpreting/interpretations/test_switch_interpretation.py b/tests/unit/interpreting/interpretations/test_switch_interpretation.py index b9e4943f3..b7f624414 100644 --- a/tests/unit/interpreting/interpretations/test_switch_interpretation.py +++ b/tests/unit/interpreting/interpretations/test_switch_interpretation.py @@ -42,3 +42,15 @@ def test_missing_without_default_without_error(blank_context): properties = blank_context.desired_ingest.source.properties assert_that(properties, not_(has_entry("success", True))) assert_that(properties, not_(has_entry("random", True))) + + +def test_switch_with_multiple_interpretations(blank_context): + subject = SwitchInterpretation( + switch_on="foo", + cases={"foo": [INTERPRETATION_USED_AS_HIT, INTERPRETATION_FOR_RANDOM]}, + fail_on_unhandled=False, + ) + subject.interpret(blank_context) + properties = blank_context.desired_ingest.source.properties + assert_that(properties, has_entry("success", True)) + assert_that(properties, has_entry("random", True))