diff --git a/.gitignore b/.gitignore index 43edc0a6..45da9a00 100644 --- a/.gitignore +++ b/.gitignore @@ -37,4 +37,3 @@ htmlcov tests/TestPass.csv tests/parsed_log.csv _trial_marker -/test.py \ No newline at end of file diff --git a/core/basic_models/actions/basic_actions.py b/core/basic_models/actions/basic_actions.py index 0f158985..f84dd953 100644 --- a/core/basic_models/actions/basic_actions.py +++ b/core/basic_models/actions/basic_actions.py @@ -1,6 +1,7 @@ # coding: utf-8 +import asyncio import random -from typing import Union, Dict, List, Any, Optional, AsyncGenerator +from typing import Union, Dict, List, Any, Optional import core.logging.logger_constants as log_const from core.basic_models.actions.command import Command @@ -32,9 +33,8 @@ def __init__(self, items: Optional[Dict[str, Any]] = None, id: Optional[str] = N self.version = items.get("version", -1) async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: raise NotImplementedError - yield def on_run_error(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser): log("exc_handler: Action failed to run. Return None. MESSAGE: %(masked_message)s.", user, @@ -72,8 +72,11 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.nodes = items.get("nodes") or {} async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: - yield Command(self.command, self.nodes, self.id, request_type=self.request_type, request_data=self.request_data) + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] + commands.append(Command(self.command, self.nodes, self.id, request_type=self.request_type, + request_data=self.request_data)) + return commands class RequirementAction(Action): @@ -102,10 +105,11 @@ def build_internal_item(self) -> str: return self._item async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] if self.requirement.check(text_preprocessing_result, user, params): - async for command in self.internal_item.run(user, text_preprocessing_result, params): - yield command + commands.extend(await self.internal_item.run(user, text_preprocessing_result, params) or []) + return commands class ChoiceAction(Action): @@ -137,18 +141,18 @@ def build_else_item(self) -> Optional[str]: return self._else_item async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] choice_is_made = False for item in self.items: checked = item.requirement.check(text_preprocessing_result, user, params) if checked: - async for command in item.internal_item.run(user, text_preprocessing_result, params): - yield command + commands.extend(await item.internal_item.run(user, text_preprocessing_result, params) or []) choice_is_made = True break if not choice_is_made and self._else_item: - async for command in self.else_item.run(user, text_preprocessing_result, params): - yield command + commands.extend(await self.else_item.run(user, text_preprocessing_result, params) or []) + return commands class ElseAction(Action): @@ -185,16 +189,14 @@ def build_item(self) -> str: def build_else_item(self) -> Optional[str]: return self._else_item - async def run( - self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Optional[Dict[str, Union[str, float, int]]]] = None - ) -> AsyncGenerator[Command, None]: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Optional[Dict[str, Union[str, float, int]]]] = None) -> List[Command]: + commands = [] if self.requirement.check(text_preprocessing_result, user, params): - async for command in self.item.run(user, text_preprocessing_result, params): - yield command + commands.extend(await self.item.run(user, text_preprocessing_result, params) or []) elif self._else_item: - async for command in self.else_item.run(user, text_preprocessing_result, params): - yield command + commands.extend(await self.else_item.run(user, text_preprocessing_result, params) or []) + return commands class ActionOfActions(Action): @@ -213,10 +215,11 @@ def build_actions(self) -> List[Action]: class CompositeAction(ActionOfActions): async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] for action in self.actions: - async for command in action.run(user, text_preprocessing_result, params): - yield command + commands.extend(await action.run(user, text_preprocessing_result, params) or []) + return commands class NonRepeatingAction(ActionOfActions): @@ -228,7 +231,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self._last_action_ids_storage = items["last_action_ids_storage"] async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] last_ids = user.last_action_ids[self._last_action_ids_storage] all_indexes = list(range(self._actions_count)) max_last_ids_count = self._actions_count - 1 @@ -238,8 +242,8 @@ async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreproces action_index = random.choice(available_indexes) action = self.actions[action_index] last_ids.add(action_index) - async for command in action.run(user, text_preprocessing_result, params): - yield command + commands.extend(await action.run(user, text_preprocessing_result, params) or []) + return commands class RandomAction(Action): @@ -255,8 +259,9 @@ def build_actions(self) -> List[Action]: return self._raw_actions async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] pos = random.randint(0, len(self._raw_actions) - 1) action = self.actions[pos] - async for command in action.run(user, text_preprocessing_result, params=params): - yield command + commands.extend(await action.run(user, text_preprocessing_result, params=params) or []) + return commands diff --git a/core/basic_models/actions/client_profile_actions.py b/core/basic_models/actions/client_profile_actions.py index 0289d4df..b3093f25 100644 --- a/core/basic_models/actions/client_profile_actions.py +++ b/core/basic_models/actions/client_profile_actions.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Optional, Union, AsyncGenerator +from typing import Dict, Any, Optional, Union, List from core.basic_models.actions.command import Command from core.basic_models.actions.string_actions import StringAction @@ -73,15 +73,15 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.request_data[KAFKA_REPLY_TOPIC] = config["template_settings"]["consumer_topic"] async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: if self.behavior: callback_id = user.message.generate_new_callback_id() scenario_id = user.last_scenarios.last_scenario_name if hasattr(user, 'last_scenarios') else None user.behaviors.add(callback_id, self.behavior, scenario_id, text_preprocessing_result.raw, pickle_deepcopy(params)) - async for command in super().run(user, text_preprocessing_result, params): - yield command + commands = await super().run(user, text_preprocessing_result, params) + return commands class RememberThisAction(StringAction): @@ -157,7 +157,7 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): }) async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: self._nodes.update({ "consumer": { "projectId": user.settings["template_settings"]["project_id"] @@ -174,5 +174,5 @@ async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessing if REPLY_TOPIC_KEY not in self.request_data and KAFKA_REPLY_TOPIC not in self.request_data: self.request_data[KAFKA_REPLY_TOPIC] = user.settings["template_settings"]["consumer_topic"] - async for command in super().run(user, text_preprocessing_result, params): - yield command + commands = await super().run(user, text_preprocessing_result, params) + return commands diff --git a/core/basic_models/actions/counter_actions.py b/core/basic_models/actions/counter_actions.py index a1b1ea4c..70032173 100644 --- a/core/basic_models/actions/counter_actions.py +++ b/core/basic_models/actions/counter_actions.py @@ -1,5 +1,5 @@ # coding: utf-8 -from typing import Union, Dict, Any, Optional, AsyncGenerator +from typing import Union, Dict, Any, Optional, List from core.basic_models.actions.basic_actions import Action from core.basic_models.actions.command import Command @@ -22,26 +22,26 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): class CounterIncrementAction(CounterAction): async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] user.counters[self.key].inc(self.value, self.lifetime) - return - yield + return commands class CounterDecrementAction(CounterAction): async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] user.counters[self.key].dec(-self.value, self.lifetime) - return - yield + return commands class CounterClearAction(CounterAction): async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] user.counters.clear(self.key) - return - yield + return commands class CounterSetAction(CounterAction): @@ -58,10 +58,10 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.time_shift = items.get("time_shift", 0) async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] user.counters[self.key].set(self.value, self.reset_time, self.time_shift) - return - yield + return commands class CounterCopyAction(Action): @@ -73,8 +73,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.time_shift = items.get("time_shift", 0) async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] value = user.counters[self.src].value user.counters[self.dst].set(value, self.reset_time, self.time_shift) - return - yield + return commands diff --git a/core/basic_models/actions/external_actions.py b/core/basic_models/actions/external_actions.py index 5911f4e2..051bfff2 100644 --- a/core/basic_models/actions/external_actions.py +++ b/core/basic_models/actions/external_actions.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict, Any, Union, AsyncGenerator +from typing import Optional, Dict, Any, Union, List from core.basic_models.actions.basic_actions import CommandAction, Action from core.basic_models.actions.basic_actions import action_factory @@ -21,7 +21,7 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self._action_key: str = items["action"] async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: action: Action = user.descriptions["external_actions"][self._action_key] - async for command in action.run(user, text_preprocessing_result, params): - yield command + commands = await action.run(user, text_preprocessing_result, params) + return commands diff --git a/core/basic_models/actions/push_action.py b/core/basic_models/actions/push_action.py index f7c894ac..b0b397f1 100644 --- a/core/basic_models/actions/push_action.py +++ b/core/basic_models/actions/push_action.py @@ -1,7 +1,7 @@ # coding: utf-8 import base64 import uuid -from typing import Union, Dict, Any, Optional, AsyncGenerator +from typing import Union, Dict, List, Any, Optional from core.basic_models.actions.command import Command from core.basic_models.actions.string_actions import StringAction @@ -69,7 +69,7 @@ def _render_request_data(self, action_params): return request_data async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: params = params or {} command_params = { "projectId": user.settings["template_settings"]["project_id"], @@ -78,8 +78,9 @@ async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessing "content": self._generate_command_context(user, text_preprocessing_result, params), } requests_data = self._render_request_data(params) - yield Command(self.command, command_params, self.id, request_type=self.request_type, - request_data=requests_data, need_payload_wrap=False, need_message_name=False) + commands = [Command(self.command, command_params, self.id, request_type=self.request_type, + request_data=requests_data, need_payload_wrap=False, need_message_name=False)] + return commands class PushAuthenticationActionHttp(PushAction): @@ -132,12 +133,11 @@ def _create_authorization_token(self, items: Dict[str, Any]) -> str: return authorization_token async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: params = params or {} collected = user.parametrizer.collect(text_preprocessing_result, filter_params={"command": self.command}) params.update(collected) - async for command in self.http_request_action.run(user, text_preprocessing_result, params): - yield command + return await self.http_request_action.run(user, text_preprocessing_result, params) class GetRuntimePermissionsAction(PushAction): @@ -167,7 +167,7 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.command = GET_RUNTIME_PERMISSIONS async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: params = params or {} scenario_id = user.last_scenarios.last_scenario_name user.behaviors.add(user.message.generate_new_callback_id(), self.behavior, scenario_id, @@ -183,8 +183,9 @@ async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessing } } command_params = self._generate_command_context(user, text_preprocessing_result, params) - yield Command(self.command, command_params, self.id, request_type=self.request_type, - request_data=self.request_data, need_payload_wrap=False, need_message_name=False) + commands = [Command(self.command, command_params, self.id, request_type=self.request_type, + request_data=self.request_data, need_payload_wrap=False, need_message_name=False)] + return commands class PushActionHttp(PushAction): @@ -309,7 +310,7 @@ def _create_instance_of_http_request_action(self, items: Dict[str, Any], id: Opt self.http_request_action = HTTPRequestAction(items, id) async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: params = params or {} collected = user.parametrizer.collect(text_preprocessing_result, filter_params={"command": self.command}) params.update(collected) @@ -330,5 +331,4 @@ async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessing "payload": self.payload } self.http_request_action.method_params["json"] = request_body_parameters - async for command in self.http_request_action.run(user, text_preprocessing_result, params): - yield command + return await self.http_request_action.run(user, text_preprocessing_result, params) diff --git a/core/basic_models/actions/string_actions.py b/core/basic_models/actions/string_actions.py index 52d93c22..234a01fa 100644 --- a/core/basic_models/actions/string_actions.py +++ b/core/basic_models/actions/string_actions.py @@ -1,9 +1,10 @@ # coding: utf-8 +import asyncio import random from copy import copy from functools import cached_property from itertools import chain -from typing import Union, Dict, List, Any, Optional, Tuple, TypeVar, Type, AsyncGenerator +from typing import Union, Dict, List, Any, Optional, Tuple, TypeVar, Type from core.basic_models.actions.basic_actions import CommandAction from core.basic_models.actions.command import Command @@ -87,9 +88,8 @@ def _get_rendered_tree_recursive(self, value: T, params: Dict, no_empty=False) - return result async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: raise NotImplementedError - yield class StringAction(NodeAction): @@ -129,7 +129,7 @@ def _generate_command_context(self, user: BaseUser, text_preprocessing_result: B return command_params async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: # Example: Command("ANSWER_TO_USER", {"answer": {"key1": "string1", "keyN": "stringN"}}) params = params or {} command_params = self._generate_command_context(user, text_preprocessing_result, params) @@ -142,8 +142,9 @@ async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreproces } }) - yield Command(self.command, command_params, self.id, request_type=self.request_type, - request_data=self.request_data) + commands = [Command(self.command, command_params, self.id, request_type=self.request_type, + request_data=self.request_data)] + return commands class AfinaAnswerAction(NodeAction): @@ -165,9 +166,10 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.command: str = ANSWER_TO_USER async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: params = user.parametrizer.collect(text_preprocessing_result, filter_params={"command": self.command}) answer_params = dict() + result = [] nodes = self.nodes.items() if self.nodes else [] for key, template in nodes: @@ -178,8 +180,9 @@ async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreproces answer_params[key] = rendered if answer_params: - yield Command(self.command, answer_params, self.id, request_type=self.request_type, - request_data=self.request_data) + result = [Command(self.command, answer_params, self.id, request_type=self.request_type, + request_data=self.request_data)] + return result class SDKAnswer(NodeAction): @@ -237,7 +240,7 @@ class SDKAnswer(NodeAction): ['suggestions', 'buttons', INDEX_WILDCARD, 'title']] def __init__(self, items: Dict[str, Any], id: Optional[str] = None): - super().__init__(items, id) + super(SDKAnswer, self).__init__(items, id) self.command: str = ANSWER_TO_USER if self._nodes == {}: self._nodes = {i: items.get(i) for i in items if @@ -270,13 +273,22 @@ def do_random(self, input_dict: Union[list, dict]): d[k] = d[k][random_index % len(d[k])] async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + result = [] params = user.parametrizer.collect(text_preprocessing_result, filter_params={"command": self.command}) rendered = self._get_rendered_tree(self.nodes, params, self.no_empty_nodes) self.do_random(rendered) if rendered or not self.no_empty_nodes: - yield Command(self.command, rendered, self.id, - request_type=self.request_type, request_data=self.request_data) + result = [ + Command( + self.command, + rendered, + self.id, + request_type=self.request_type, + request_data=self.request_data, + ) + ] + return result class SDKAnswerToUser(NodeAction): @@ -414,7 +426,9 @@ def build_root(self): return self._root async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + + result = [] params = user.parametrizer.collect(text_preprocessing_result, filter_params={self.COMMAND: self.command}) rendered = self._get_rendered_tree(self.nodes[self.STATIC], params, self.no_empty_nodes) if self._nodes[self.RANDOM_CHOICE]: @@ -440,4 +454,13 @@ async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreproces if part.requirement.check(text_preprocessing_result, user): out.update(part.render(rendered)) if rendered or not self.no_empty_nodes: - yield Command(self.command, out, self.id, request_type=self.request_type, request_data=self.request_data) + result = [ + Command( + self.command, + out, + self.id, + request_type=self.request_type, + request_data=self.request_data, + ) + ] + return result diff --git a/core/basic_models/actions/variable_actions.py b/core/basic_models/actions/variable_actions.py index e4b4f063..a39cddcd 100644 --- a/core/basic_models/actions/variable_actions.py +++ b/core/basic_models/actions/variable_actions.py @@ -1,6 +1,6 @@ import collections import json -from typing import Optional, Dict, Any, Union, AsyncGenerator +from typing import Optional, Dict, Any, Union, List from jinja2 import exceptions as jexcept @@ -29,7 +29,8 @@ def _set(self, user, value): raise NotImplementedError async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] params = user.parametrizer.collect(text_preprocessing_result) try: # if path is wrong, it may fail with UndefinedError @@ -47,8 +48,7 @@ async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreproces value = None self._set(user, value) - return - yield + return commands class SetVariableAction(BaseSetVariableAction): @@ -77,10 +77,10 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.key: str = items["key"] async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] user.variables.delete(self.key) - return - yield + return commands class ClearVariablesAction(Action): @@ -90,10 +90,10 @@ def __init__(self, items: Dict[str, Any] = None, id: Optional[str] = None): super().__init__(items, id) async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] user.variables.clear() - return - yield + return commands class SetLocalVariableAction(BaseSetVariableAction): diff --git a/core/basic_models/scenarios/base_scenario.py b/core/basic_models/scenarios/base_scenario.py index 16f037d0..510e04f6 100644 --- a/core/basic_models/scenarios/base_scenario.py +++ b/core/basic_models/scenarios/base_scenario.py @@ -1,5 +1,5 @@ # coding: utf-8 -from typing import Dict, Any, List, AsyncGenerator +from typing import Dict, Any, List import core.logging.logger_constants as log_const import scenarios.logging.logger_constants as scenarios_log_const @@ -67,40 +67,44 @@ def text_fits(self, text_preprocessing_result, user): return False async def get_no_commands_action(self, user, text_preprocessing_result, - params: Dict[str, Any] = None) -> AsyncGenerator[Command, None]: + params: Dict[str, Any] = None) -> List[Command]: log_params = {log_const.KEY_NAME: scenarios_log_const.CHOSEN_ACTION_VALUE, scenarios_log_const.CHOSEN_ACTION_VALUE: self._empty_answer} log(scenarios_log_const.CHOSEN_ACTION_MESSAGE, user, log_params) try: - async for command in self.empty_answer.run(user, text_preprocessing_result, params): - yield command + empty_answer = await self.empty_answer.run(user, text_preprocessing_result, params) or [] except KeyError: log_params = {log_const.KEY_NAME: scenarios_log_const.CHOSEN_ACTION_VALUE} log("Scenario has empty answer, but empty_answer action isn't defined", params=log_params, level='WARNING') + empty_answer = [] + return empty_answer async def get_action_results(self, user, text_preprocessing_result, - actions: List[Action], params: Dict[str, Any] = None) -> AsyncGenerator[Command, None]: + actions: List[Action], params: Dict[str, Any] = None) -> List[Command]: + results = [] for action in actions: + result = await action.run(user, text_preprocessing_result, params) or [] log_params = self._log_params() log_params["class"] = action.__class__.__name__ log("called action: %(class)s", user, log_params) - async for command in action.run(user, text_preprocessing_result, params): - if command.action_id: - log_params = self._log_params() - log_params["id"] = command.action_id - log("external action id: %(id)s", user, log_params) - log_params = self._log_params() - log_params["name"] = command.name - log("action result name: %(name)s", user, log_params) - yield command + if result: + for command in result: + if command.action_id: + log_params = self._log_params() + log_params["id"] = command.action_id + log("external action id: %(id)s", user, log_params) + + log_params = self._log_params() + log_params["name"] = command.name + log("action result name: %(name)s", user, log_params) + results.extend(result) + return results @property def history(self): return {"scenario_path": [{"scenario": self.id, "node": None}]} - async def run(self, text_preprocessing_result, user, - params: Dict[str, Any] = None) -> AsyncGenerator[Command, None]: - async for command in self.get_action_results(user, text_preprocessing_result, self.actions, params): - yield command + async def run(self, text_preprocessing_result, user, params: Dict[str, Any] = None) -> List[Command]: + return await self.get_action_results(user, text_preprocessing_result, self.actions, params) diff --git a/core/utils/exception_handlers.py b/core/utils/exception_handlers.py index 52c20895..d478f717 100644 --- a/core/utils/exception_handlers.py +++ b/core/utils/exception_handlers.py @@ -1,5 +1,4 @@ import asyncio -import inspect import sys from functools import wraps @@ -29,29 +28,6 @@ async def _wrapper(obj, *args, **kwarg): print(sys.exc_info()) return result - return _wrapper - elif inspect.isasyncgenfunction(funct): - @wraps(funct) - async def _wrapper(obj, *args, **kwarg): - try: - async for yield_ in funct(obj, *args, **kwarg): - yield yield_ - except handled_exceptions: - try: - on_error = ( - getattr(obj, on_error_obj_method_name) - if on_error_obj_method_name else (lambda *x, **y: None) - ) - if asyncio.iscoroutinefunction(on_error): - yield await on_error(*args, **kwarg) - elif inspect.isasyncgenfunction(on_error): - async for yield_ in on_error(*args, **kwarg): - yield yield_ - else: - yield on_error(*args, **kwarg) - except Exception: - print(sys.exc_info()) - return _wrapper else: @wraps(funct) diff --git a/scenarios/actions/action.py b/scenarios/actions/action.py index f69882d8..8715a338 100644 --- a/scenarios/actions/action.py +++ b/scenarios/actions/action.py @@ -1,7 +1,8 @@ +import asyncio import copy import time from functools import cached_property -from typing import Optional, Dict, Any, Union, AsyncGenerator +from typing import Optional, Dict, Any, Union, List from core.basic_models.actions.basic_actions import Action from core.basic_models.actions.command import Command @@ -35,10 +36,10 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.form = items["form"] async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] user.forms.remove_item(self.form) - return - yield + return commands class ClearInnerFormAction(ClearFormAction): @@ -52,12 +53,12 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.inner_form = items["inner_form"] async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] form = user.forms[self.form] if form: form.forms.remove_item(self.inner_form) - return - yield + return commands class RemoveFormFieldAction(Action): @@ -72,11 +73,11 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.field = items["field"] async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] form = user.forms[self.form] form.fields.remove_item(self.field) - return - yield + return commands class RemoveCompositeFormFieldAction(RemoveFormFieldAction): @@ -91,12 +92,12 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.inner_form = items["inner_form"] async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] form = user.forms[self.form] inner_form = form.forms[self.inner_form] inner_form.fields.remove_item(self.field) - return - yield + return commands class BreakScenarioAction(Action): @@ -107,11 +108,11 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.scenario_id = items.get("scenario_id") async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] scenario_id = self.scenario_id if self.scenario_id is not None else user.last_scenarios.last_scenario_name user.scenario_models[scenario_id].set_break() - return - yield + return commands class SaveBehaviorAction(Action): @@ -126,14 +127,14 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.check_scenario = items.get("check_scenario", True) async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] scenario_id = None if self.check_scenario: scenario_id = user.last_scenarios.last_scenario_name user.behaviors.add(user.message.generate_new_callback_id(), self.behavior, scenario_id, text_preprocessing_result.raw, action_params=pickle_deepcopy(params)) - return - yield + return commands class BasicSelfServiceActionWithState(StringAction): @@ -159,17 +160,17 @@ def command_action(self) -> StringAction: def _check(self, user): return not user.behaviors.check_got_saved_id(self.behavior_action.behavior) - async def _run(self, user, text_preprocessing_result, params=None) -> AsyncGenerator[Command, None]: - async for command in self.behavior_action.run(user, text_preprocessing_result, params): - yield command - async for command in self.command_action.run(user, text_preprocessing_result, params): - yield command + async def _run(self, user, text_preprocessing_result, params=None): + await self.behavior_action.run(user, text_preprocessing_result, params) + command_action_result = await self.command_action.run(user, text_preprocessing_result, params) or [] + return command_action_result async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Union[None, str, List[Command]]: + commands = [] if self._check(user): - async for command in self._run(user, text_preprocessing_result, params): - yield command + commands.extend(await self._run(user, text_preprocessing_result, params)) + return commands class DeleteVariableAction(Action): @@ -177,27 +178,23 @@ class DeleteVariableAction(Action): key: str def __init__(self, items: Dict[str, Any], id: Optional[str] = None): - super().__init__(items, id) + super(DeleteVariableAction, self).__init__(items, id) self.key: str = items["key"] async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: user.variables.delete(self.key) - return - yield class ClearVariablesAction(Action): version: Optional[int] def __init__(self, items: Dict[str, Any] = None, id: Optional[str] = None): - super().__init__(items, id) + super(ClearVariablesAction, self).__init__(items, id) async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: user.variables.clear() - return - yield class FillFieldAction(Action): @@ -223,12 +220,12 @@ def _get_data(self, params): return self.template.render(params) async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] params = user.parametrizer.collect(text_preprocessing_result) data = self._get_data(params) self._fill(user, data) - return - yield + return commands class CompositeFillFieldAction(FillFieldAction): @@ -255,7 +252,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.scenario: UnifiedTemplate = UnifiedTemplate(items["scenario"]) async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] if params is None: params = user.parametrizer.collect(text_preprocessing_result) else: @@ -263,18 +261,19 @@ async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessing scenario_id = self.scenario.render(params) scenario = user.descriptions["scenarios"].get(scenario_id) if scenario: - async for command in scenario.run(text_preprocessing_result, user, params): - yield command + commands.extend(await scenario.run(text_preprocessing_result, user, params)) + return commands class RunLastScenarioAction(Action): async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] last_scenario_id = user.last_scenarios.last_scenario_name scenario = user.descriptions["scenarios"].get(last_scenario_id) if scenario: - async for command in scenario.run(text_preprocessing_result, user, params): - yield command + commands.extend(await scenario.run(text_preprocessing_result, user, params)) + return commands class ChoiceScenarioAction(Action): @@ -304,20 +303,21 @@ def build_else_item(self): return self._else_item async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] choice_is_made = False for scenario, requirement in zip(self._scenarios, self.requirement_items): check_res = requirement.check(text_preprocessing_result, user, params) if check_res: - async for command in RunScenarioAction(items=scenario).run(user, text_preprocessing_result, params): - yield command + commands.extend(await RunScenarioAction(items=scenario).run(user, text_preprocessing_result, params)) choice_is_made = True break if not choice_is_made and self._else_item: - async for command in self.else_item.run(user, text_preprocessing_result, params): - yield command + commands.extend(await self.else_item.run(user, text_preprocessing_result, params) or []) + + return commands def clear_scenario(user, scenario_id): @@ -328,20 +328,20 @@ def clear_scenario(user, scenario_id): class ClearCurrentScenarioAction(Action): async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] last_scenario_id = user.last_scenarios.last_scenario_name if last_scenario_id: clear_scenario(user, last_scenario_id) - return - yield + return commands class ClearAllScenariosAction(Action): async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] user.last_scenarios.clear_all() - return - yield + return commands class ClearScenarioByIdAction(Action): @@ -354,11 +354,11 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.scenario_id = items.get("scenario_id") async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] if self.scenario_id: clear_scenario(user, self.scenario_id) - return - yield + return commands class ClearCurrentScenarioFormAction(Action): @@ -366,12 +366,12 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): super().__init__(items, id) async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] last_scenario_id = user.last_scenarios.last_scenario_name if last_scenario_id: user.forms.clear_form(last_scenario_id) - return - yield + return commands class ResetCurrentNodeAction(Action): @@ -380,12 +380,12 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.node_id = items.get('node_id', None) async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] last_scenario_id = user.last_scenarios.last_scenario_name if last_scenario_id: user.scenario_models[last_scenario_id].current_node = self.node_id - return - yield + return commands class AddHistoryEventAction(Action): @@ -403,7 +403,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.event_content[k] = UnifiedTemplate(v) async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] last_scenario_id = user.last_scenarios.last_scenario_name scenario = user.descriptions["scenarios"].get(last_scenario_id) if scenario: @@ -424,38 +425,39 @@ async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessing content=self.event_content ) user.history.add_event(event) - return - yield + return commands class EmptyAction(Action): async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] log("%(class_name)s.run: action do nothing.", params={log_const.KEY_NAME: "empty_action", "class_name": self.__class__.__name__}, user=user) - return - yield + return commands class RunScenarioByProjectNameAction(Action): async def run(self, user: User, text_preprocessing_result: TextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] scenario_id = user.message.project_name scenario = user.descriptions["scenarios"].get(scenario_id) if scenario: - async for command in scenario.run(text_preprocessing_result, user, params): - yield command + commands.extend(await scenario.run(text_preprocessing_result, user, params)) else: log("%(class_name)s warning: %(scenario_id)s isn't exist", params={log_const.KEY_NAME: "warning_in_RunScenarioByProjectNameAction", "class_name": self.__class__.__name__, "scenario_id": scenario_id}, user=user, level="WARNING") + return commands class ProcessBehaviorAction(Action): async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] callback_id = user.message.callback_id log(f"%(class_name)s.run: got callback_id %({log_const.BEHAVIOR_CALLBACK_ID_VALUE})s.", @@ -468,14 +470,13 @@ async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessing params={log_const.KEY_NAME: "process_behavior_action_warning", "class_name": self.__class__.__name__, log_const.BEHAVIOR_CALLBACK_ID_VALUE: callback_id}, level="WARNING", user=user) - return + return commands if user.message.payload: - async for command in user.behaviors.success(callback_id): - yield command + commands.extend(await user.behaviors.success(callback_id)) else: - async for command in user.behaviors.fail(callback_id): - yield command + commands.extend(await user.behaviors.fail(callback_id)) + return commands class SelfServiceActionWithState(BasicSelfServiceActionWithState): @@ -490,7 +491,7 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.rewrite_saved_messages = items.get("rewrite_saved_messages", False) self._check_scenario: bool = items.get("check_scenario", True) - async def _run(self, user, text_preprocessing_result, params=None) -> AsyncGenerator[Command, None]: + async def _run(self, user, text_preprocessing_result, params=None) -> List[Command]: action_params = copy.copy(params or {}) command_params = dict() @@ -513,8 +514,9 @@ async def _run(self, user, text_preprocessing_result, params=None) -> AsyncGener save_params = self._get_save_params(user, action_params, command_params) self._save_behavior(callback_id, user, scenario, text_preprocessing_result, save_params) - yield Command(self.command, command_params, self.id, request_type=self.request_type, - request_data=request_data) + commands = [Command(self.command, command_params, self.id, request_type=self.request_type, + request_data=request_data)] + return commands def _get_extra_request_data(self, user, params, callback_id): extra_request_data = {} diff --git a/scenarios/behaviors/behaviors.py b/scenarios/behaviors/behaviors.py index 0e4b9924..c50afa8a 100644 --- a/scenarios/behaviors/behaviors.py +++ b/scenarios/behaviors/behaviors.py @@ -3,14 +3,15 @@ import socket from collections import namedtuple from time import time -from typing import Dict, AsyncGenerator +from typing import Dict, List import scenarios.logging.logger_constants as log_const from core.basic_models.actions.command import Command from core.logging.logger_utils import log +from core.names.field import APP_INFO from core.text_preprocessing.preprocessing_result import TextPreprocessingResult from core.utils.pickle_copy import pickle_deepcopy -from scenarios.actions.action_params_names import TO_MESSAGE_NAME, LOCAL_VARS +from scenarios.actions.action_params_names import TO_MESSAGE_NAME, TO_MESSAGE_PARAMS, LOCAL_VARS from core.monitoring.monitoring import monitoring @@ -128,8 +129,9 @@ def _log_callback( params=log_params, ) - async def success(self, callback_id: str) -> AsyncGenerator[Command, None]: + async def success(self, callback_id: str) -> List[Command]: callback = self._get_callback(callback_id) + result = [] if callback: self._check_hostname(callback_id, callback) self._add_returned_callback(callback_id) @@ -143,18 +145,19 @@ async def success(self, callback_id: str) -> AsyncGenerator[Command, None]: callback_action_params, ) text_preprocessing_result = TextPreprocessingResult(callback.text_preprocessing_result) - async for command in behavior.success_action.run(self._user, text_preprocessing_result, - callback_action_params): - yield command + result = await behavior.success_action.run(self._user, text_preprocessing_result, callback_action_params) + result = result or [] else: log(f"behavior.success not found valid callback for callback_id %({log_const.BEHAVIOR_CALLBACK_ID_VALUE})s", self._user, params={log_const.KEY_NAME: log_const.BEHAVIOR_SUCCESS_VALUE, log_const.BEHAVIOR_CALLBACK_ID_VALUE: callback_id}) self._delete(callback_id) + return result - async def fail(self, callback_id: str) -> AsyncGenerator[Command, None]: + async def fail(self, callback_id: str) -> List[Command]: callback = self._get_callback(callback_id) + result = [] if callback: self._check_hostname(callback_id, callback) self._add_returned_callback(callback_id) @@ -164,18 +167,18 @@ async def fail(self, callback_id: str) -> AsyncGenerator[Command, None]: callback_id, "behavior_fail", monitoring.counter_behavior_fail, "fail", callback_action_params ) text_preprocessing_result = TextPreprocessingResult(callback.text_preprocessing_result) - async for command in behavior.fail_action.run(self._user, text_preprocessing_result, - callback_action_params): - yield command + result = await behavior.fail_action.run(self._user, text_preprocessing_result, callback_action_params) or [] else: log(f"behavior.fail not found valid callback for callback_id %({log_const.BEHAVIOR_CALLBACK_ID_VALUE})s", self._user, params={log_const.KEY_NAME: log_const.BEHAVIOR_FAIL_VALUE, log_const.BEHAVIOR_CALLBACK_ID_VALUE: callback_id}) self._delete(callback_id) + return result - async def timeout(self, callback_id: str) -> AsyncGenerator[Command, None]: + async def timeout(self, callback_id: str) -> List[Command]: callback = self._get_callback(callback_id) + result = [] if callback: self._add_returned_callback(callback_id) behavior = self.descriptions[callback.behavior_id] @@ -184,18 +187,19 @@ async def timeout(self, callback_id: str) -> AsyncGenerator[Command, None]: callback_id, "behavior_timeout", monitoring.counter_behavior_timeout, "timeout", callback_action_params ) text_preprocessing_result = TextPreprocessingResult(callback.text_preprocessing_result) - async for command in behavior.timeout_action.run(self._user, text_preprocessing_result, - callback_action_params): - yield command + result = await behavior.timeout_action.run(self._user, text_preprocessing_result, callback_action_params) + result = result or [] else: log(f"behavior.timeout not found valid callback for callback_id %({log_const.BEHAVIOR_CALLBACK_ID_VALUE})s", self._user, params={log_const.KEY_NAME: log_const.BEHAVIOR_TIMEOUT_VALUE, log_const.BEHAVIOR_CALLBACK_ID_VALUE: callback_id}) self._delete(callback_id) + return result - async def misstate(self, callback_id: str) -> AsyncGenerator[Command, None]: + async def misstate(self, callback_id: str) -> List[Command]: callback = self._get_callback(callback_id) + result = [] if callback: self._check_hostname(callback_id, callback) self._add_returned_callback(callback_id) @@ -209,9 +213,8 @@ async def misstate(self, callback_id: str) -> AsyncGenerator[Command, None]: callback_action_params, ) text_preprocessing_result = TextPreprocessingResult(callback.text_preprocessing_result) - async for command in behavior.misstate_action.run(self._user, text_preprocessing_result, - callback_action_params): - yield command + result = await behavior.misstate_action.run(self._user, text_preprocessing_result, callback_action_params) + result = result or [] else: log("behavior.misstate not found valid callback" f" for callback_id %({log_const.BEHAVIOR_CALLBACK_ID_VALUE})s", @@ -219,6 +222,7 @@ async def misstate(self, callback_id: str) -> AsyncGenerator[Command, None]: params={log_const.KEY_NAME: log_const.BEHAVIOR_MISSTATE_VALUE, log_const.BEHAVIOR_CALLBACK_ID_VALUE: callback_id}) self._delete(callback_id) + return result def _get_callback(self, callback_id): callback = self._callbacks.get(callback_id) diff --git a/scenarios/scenario_descriptions/form_filling_scenario.py b/scenarios/scenario_descriptions/form_filling_scenario.py index fa4780ad..51790032 100644 --- a/scenarios/scenario_descriptions/form_filling_scenario.py +++ b/scenarios/scenario_descriptions/form_filling_scenario.py @@ -47,8 +47,7 @@ async def ask_again(self, text_preprocessing_result, user, params): content={HistoryConstants.content_fields.FIELD: question_field.description.id}, result=HistoryConstants.event_results.ASK_QUESTION)) - async for command in question.run(user, text_preprocessing_result, params): - yield command + return await question.run(user, text_preprocessing_result, params) def _check_field(self, text_preprocessing_result, user, params): form = user.forms[self.form_type] @@ -115,6 +114,7 @@ def _extract_data(self, form, text_normalization_result, user, params): async def _validate_extracted_data(self, user, text_preprocessing_result, form, data_extracted, params) -> List[Command]: + error_msgs = [] for field_key, field in form.description.fields.items(): value = data_extracted.get(field_key) # is not None is necessary, because 0 and False should be checked, None - shouldn't fill @@ -126,9 +126,9 @@ async def _validate_extracted_data(self, user, text_preprocessing_result, form, message = "Field is not valid: %(field_key)s" log(message, user, log_params) actions = field.field_validator.actions - async for command in self.get_action_results(user, text_preprocessing_result, actions): - yield command + error_msgs = await self.get_action_results(user, text_preprocessing_result, actions) break + return error_msgs async def _fill_form(self, user, text_preprocessing_result, form, data_extracted) -> Tuple[List[Command], bool]: on_filled_actions = [] @@ -160,8 +160,7 @@ async def get_reply(self, user, text_preprocessing_result, reply_actions, field, message = "Ask question on field: %(field)s" log(message, user, params) action_params[REQUEST_FIELD] = {"type": field.description.type, "id": field.description.id} - async for command in self.get_action_results(user, text_preprocessing_result, actions, action_params): - yield command + action_messages = await self.get_action_results(user, text_preprocessing_result, actions, action_params) else: actions = reply_actions params = { @@ -171,9 +170,9 @@ async def get_reply(self, user, text_preprocessing_result, reply_actions, field, message = "Finished scenario: %(id)s" log(message, user, params) user.preprocessing_messages_for_scenarios.clear() - async for command in self.get_action_results(user, text_preprocessing_result, actions, action_params): - yield command + action_messages = await self.get_action_results(user, text_preprocessing_result, actions, action_params) user.last_scenarios.delete(self.id) + return action_messages @monitoring.got_histogram("scenario_time") async def run(self, text_preprocessing_result, user, params: Dict[str, Any] = None) -> List[Command]: diff --git a/scenarios/scenario_descriptions/tree_scenario/tree_scenario.py b/scenarios/scenario_descriptions/tree_scenario/tree_scenario.py index 80469e0d..0adfa7b4 100644 --- a/scenarios/scenario_descriptions/tree_scenario/tree_scenario.py +++ b/scenarios/scenario_descriptions/tree_scenario/tree_scenario.py @@ -14,7 +14,6 @@ class TreeScenario(FormFillingScenario): - """Warning: Not adapted for async generator interfaces of actions yet""" def __init__(self, items, id): super(TreeScenario, self).__init__(items, id) self._start_node_key = items["start_node_key"] diff --git a/smart_kit/action/http.py b/smart_kit/action/http.py index f2fe1fb8..7f383630 100644 --- a/smart_kit/action/http.py +++ b/smart_kit/action/http.py @@ -1,5 +1,5 @@ import asyncio -from typing import Optional, Dict, Union, Any, AsyncGenerator +from typing import Optional, Dict, Union, List, Any import aiohttp import aiohttp.client_exceptions @@ -114,7 +114,7 @@ def _log_response(self, user, response, additional_params=None): **additional_params, }) - async def process_result(self, response, user, text_preprocessing_result, params) -> AsyncGenerator[Command, None]: + async def process_result(self, response, user, text_preprocessing_result, params): behavior_description = user.descriptions["behaviors"][self.behavior] if self.behavior else None action = None if self.error is None: @@ -130,11 +130,10 @@ async def process_result(self, response, user, text_preprocessing_result, params else: action = behavior_description.fail_action if action: - async for command in action.run(user, text_preprocessing_result, None): - yield command + return await action.run(user, text_preprocessing_result, None) async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: self.preprocess(user, text_preprocessing_result, params) params = params or {} request_parameters = self._get_request_params(user, text_preprocessing_result, params) @@ -142,5 +141,4 @@ async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreproces response = await self._make_response(request_parameters, user) if response: log("response data: %(body)s", params={"body": response.json()}, level="INFO") - async for command in self.process_result(response, user, text_preprocessing_result, params): - yield command + return await self.process_result(response, user, text_preprocessing_result, params) diff --git a/smart_kit/action/smart_geo_action.py b/smart_kit/action/smart_geo_action.py index a4818083..d6576c1f 100644 --- a/smart_kit/action/smart_geo_action.py +++ b/smart_kit/action/smart_geo_action.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Optional, Union, AsyncGenerator +from typing import Dict, Any, Optional, Union, List from core.basic_models.actions.command import Command from core.basic_models.actions.string_actions import StringAction @@ -35,10 +35,10 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.behavior = items.get("behavior", "smart_geo_behavior") async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: scenario_id = user.last_scenarios.last_scenario_name user.behaviors.add(user.message.generate_new_callback_id(), self.behavior, scenario_id, text_preprocessing_result.raw, action_params=pickle_deepcopy(params)) - async for command in super().run(user, text_preprocessing_result, params): - yield command + commands = await super().run(user, text_preprocessing_result, params) + return commands diff --git a/smart_kit/handlers/handle_close_app.py b/smart_kit/handlers/handle_close_app.py index fff4617d..e5c604f4 100644 --- a/smart_kit/handlers/handle_close_app.py +++ b/smart_kit/handlers/handle_close_app.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, AsyncGenerator +from typing import List, Any, Dict from core.basic_models.actions.command import Command from core.logging.logger_utils import log @@ -15,17 +15,14 @@ def __init__(self, app_name: str): super().__init__(app_name) self._clear_current_scenario = ClearCurrentScenarioAction() - async def run(self, payload: Dict[str, Any], user: User) -> AsyncGenerator[Command, None]: - async for command in super().run(payload, user): - yield command - + async def run(self, payload: Dict[str, Any], user: User) -> List[Command]: + commands = await super().run(payload, user) text_preprocessing_result = TextPreprocessingResult.from_payload(payload) - async for command in self._clear_current_scenario.run(user, text_preprocessing_result): - yield command - params = { log_const.KEY_NAME: "HandlerCloseApp" } + await self._clear_current_scenario.run(user, text_preprocessing_result) if payload.get("message"): params["tpr_str"] = str(text_preprocessing_result.raw) log("HandlerCloseApp with text preprocessing result", user, params) + return commands diff --git a/smart_kit/handlers/handle_respond.py b/smart_kit/handlers/handle_respond.py index db7dfd62..8e5cce4d 100644 --- a/smart_kit/handlers/handle_respond.py +++ b/smart_kit/handlers/handle_respond.py @@ -1,5 +1,5 @@ from time import time -from typing import Optional, Dict, Any, AsyncGenerator +from typing import List, Optional, Dict, Any from core.basic_models.actions.command import Command from core.logging.logger_utils import log @@ -26,9 +26,8 @@ def get_action_params(self, payload: Dict[str, Any], user: User): callback_id = user.message.callback_id return user.behaviors.get_callback_action_params(callback_id) - async def run(self, payload: Dict[str, Any], user: User) -> AsyncGenerator[Command, None]: - async for command in super().run(payload, user): - yield command + async def run(self, payload: Dict[str, Any], user: User) -> List[Command]: + commands = await super().run(payload, user) callback_id = user.message.callback_id action_params = self.get_action_params(payload, user) action_name = self.get_action_name(payload, user) @@ -57,8 +56,8 @@ async def run(self, payload: Dict[str, Any], user: User) -> AsyncGenerator[Comma log("text preprocessing result: '%(normalized_text)s'", user, params, level="DEBUG") action = user.descriptions["external_actions"][action_name] - async for command in action.run(user, text_preprocessing_result, action_params): - yield command + commands.extend(await action.run(user, text_preprocessing_result, action_params) or []) + return commands @staticmethod def get_processing_time(user: User): diff --git a/smart_kit/handlers/handle_server_action.py b/smart_kit/handlers/handle_server_action.py index 4c6a4b5c..028304b8 100644 --- a/smart_kit/handlers/handle_server_action.py +++ b/smart_kit/handlers/handle_server_action.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Optional, AsyncGenerator +from typing import List, Dict, Any, Optional import scenarios.logging.logger_constants as log_const from core.basic_models.actions.command import Command @@ -25,9 +25,8 @@ def get_action_name(self, payload: Dict[str, Any], user: User): def get_action_params(self, payload: Dict[str, Any]): return payload[SERVER_ACTION].get("parameters", {}) - async def run(self, payload: Dict[str, Any], user: User) -> AsyncGenerator[Command, None]: - async for command in super().run(payload, user): - yield command + async def run(self, payload: Dict[str, Any], user: User) -> List[Command]: + commands = await super().run(payload, user) action_params = pickle_deepcopy(self.get_action_params(payload)) params = {log_const.KEY_NAME: "handling_server_action", "server_action_params": str(action_params), @@ -36,5 +35,5 @@ async def run(self, payload: Dict[str, Any], user: User) -> AsyncGenerator[Comma action_id = self.get_action_name(payload, user) action = user.descriptions["external_actions"][action_id] - async for command in action.run(user, TextPreprocessingResult({}), action_params): - yield command + commands.extend(await action.run(user, TextPreprocessingResult({}), action_params) or []) + return commands diff --git a/smart_kit/handlers/handle_take_runtime_permissions.py b/smart_kit/handlers/handle_take_runtime_permissions.py index a8b793e0..b727f57e 100644 --- a/smart_kit/handlers/handle_take_runtime_permissions.py +++ b/smart_kit/handlers/handle_take_runtime_permissions.py @@ -1,4 +1,4 @@ -from typing import AsyncGenerator +from typing import List from core.basic_models.actions.command import Command from core.logging.logger_utils import log @@ -11,15 +11,13 @@ class HandlerTakeRuntimePermissions(HandlerBase): SUCCESS_CODE = 1 - async def run(self, payload, user: User) -> AsyncGenerator[Command, None]: - async for command in super().run(payload, user): - yield command + async def run(self, payload, user: User) -> List[Command]: + commands = await super().run(payload, user) log(f"{self.__class__.__name__} started", user) if payload.get(STATUS_CODE, {}).get(CODE) == self.SUCCESS_CODE: - async for command in user.behaviors.success(user.message.callback_id): - yield command + commands.extend(await user.behaviors.success(user.message.callback_id)) user.variables.set("permitted_actions", payload.get(PERMITTED_ACTIONS, [])) else: - async for command in user.behaviors.fail(user.message.callback_id): - yield command + commands.extend(await user.behaviors.fail(user.message.callback_id)) user.variables.set("take_runtime_permissions_status_code", payload.get(STATUS_CODE, {})) + return commands diff --git a/smart_kit/handlers/handler_base.py b/smart_kit/handlers/handler_base.py index 6d93493c..aee2bf67 100644 --- a/smart_kit/handlers/handler_base.py +++ b/smart_kit/handlers/handler_base.py @@ -1,5 +1,5 @@ # coding: utf-8 -from typing import Dict, Any, AsyncGenerator +from typing import List, Dict, Any, Optional from core.basic_models.actions.command import Command from core.monitoring.monitoring import monitoring @@ -13,9 +13,8 @@ class HandlerBase: def __init__(self, app_name: str): self.app_name = app_name - async def run(self, payload: Dict[str, Any], user: User) -> AsyncGenerator[Command, None]: + async def run(self, payload: Dict[str, Any], user: User) -> Optional[List[Command]]: # отправка события о входящем сообщении в систему мониторинга monitoring.counter_incoming(self.app_name, user.message.message_name, self.__class__.__name__, user, app_info=user.message.app_info) - return - yield + return [] diff --git a/smart_kit/handlers/handler_run_app.py b/smart_kit/handlers/handler_run_app.py index e2d9e9e5..cc50b010 100644 --- a/smart_kit/handlers/handler_run_app.py +++ b/smart_kit/handlers/handler_run_app.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, AsyncGenerator +from typing import Dict, Any, List import scenarios.logging.logger_constants as log_const from core.basic_models.actions.command import Command @@ -21,16 +21,15 @@ def __init__(self, app_name: str, dialogue_manager: DialogueManager): f"{self.__class__.__name__}.__init__ finished.", params={log_const.KEY_NAME: log_const.STARTUP_VALUE} ) - async def run(self, payload: Dict[str, Any], user: User) -> AsyncGenerator[Command, None]: - async for command in super().run(payload, user): - yield command + async def run(self, payload: Dict[str, Any], user: User) -> List[Command]: + commands = await super().run(payload, user) params = {log_const.KEY_NAME: "handling_run_app"} log(f"{self.__class__.__name__}.run started", user, params) - async for command in self._handle_base(user): - yield command + commands.extend(await self._handle_base(user)) + return commands - async def _handle_base(self, user: User) -> AsyncGenerator[Command, None]: - async for command in self.dialogue_manager.run(TextPreprocessingResult({}), user): - yield command + async def _handle_base(self, user: User) -> List[Command]: + answer, is_answer_found = await self.dialogue_manager.run(TextPreprocessingResult({}), user) + return answer or [] diff --git a/smart_kit/handlers/handler_take_profile_data.py b/smart_kit/handlers/handler_take_profile_data.py index d20b225f..e2211ab3 100644 --- a/smart_kit/handlers/handler_take_profile_data.py +++ b/smart_kit/handlers/handler_take_profile_data.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, AsyncGenerator +from typing import List, Any, Dict from core.basic_models.actions.command import Command from core.logging.logger_utils import log @@ -11,15 +11,13 @@ class HandlerTakeProfileData(HandlerBase): SUCCESS_CODE = 1 - async def run(self, payload: Dict[str, Any], user: User) -> AsyncGenerator[Command, None]: - async for command in super().run(payload, user): - yield command + async def run(self, payload: Dict[str, Any], user: User) -> List[Command]: + commands = await super().run(payload, user) log(f"{self.__class__.__name__} started", user) if payload.get(STATUS_CODE, {}).get(CODE) == self.SUCCESS_CODE: - async for command in user.behaviors.success(user.message.callback_id): - yield command + commands.extend(await user.behaviors.success(user.message.callback_id)) user.variables.set("smart_geo", payload.get(PROFILE_DATA, {}).get(GEO)) else: - async for command in user.behaviors.fail(user.message.callback_id): - yield command + commands.extend(await user.behaviors.fail(user.message.callback_id)) + return commands diff --git a/smart_kit/handlers/handler_text.py b/smart_kit/handlers/handler_text.py index 9aab9bf4..f1c116eb 100644 --- a/smart_kit/handlers/handler_text.py +++ b/smart_kit/handlers/handler_text.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, AsyncGenerator +from typing import List, Dict, Any import scenarios.logging.logger_constants as log_const from core.basic_models.actions.command import Command @@ -21,9 +21,8 @@ def __init__(self, app_name: str, dialogue_manager: DialogueManager): f"{self.__class__.__name__}.__init__ finished.", params={log_const.KEY_NAME: log_const.STARTUP_VALUE} ) - async def run(self, payload: Dict[str, Any], user: User) -> AsyncGenerator[Command, None]: - async for command in super().run(payload, user): - yield command + async def run(self, payload: Dict[str, Any], user: User) -> List[Command]: + commands = await super().run(payload, user) text_preprocessing_result = TextPreprocessingResult.from_payload(payload) @@ -33,10 +32,9 @@ async def run(self, payload: Dict[str, Any], user: User) -> AsyncGenerator[Comma } log("text preprocessing result: '%(normalized_text)s'", user, params) - async for command in self._handle_base(text_preprocessing_result, user): - yield command + commands.extend(await self._handle_base(text_preprocessing_result, user)) + return commands - async def _handle_base(self, text_preprocessing_result: TextPreprocessingResult, - user: User) -> AsyncGenerator[Command, None]: - async for command in self.dialogue_manager.run(text_preprocessing_result, user): - yield command + async def _handle_base(self, text_preprocessing_result: TextPreprocessingResult, user: User) -> List[Command]: + answer, is_answer_found = await self.dialogue_manager.run(text_preprocessing_result, user) + return answer or [] diff --git a/smart_kit/handlers/handler_timeout.py b/smart_kit/handlers/handler_timeout.py index 8214389f..ff78b313 100644 --- a/smart_kit/handlers/handler_timeout.py +++ b/smart_kit/handlers/handler_timeout.py @@ -1,5 +1,5 @@ # coding: utf-8 -from typing import Dict, Any, AsyncGenerator +from typing import List, Dict, Any from core.basic_models.actions.command import Command from core.logging.logger_utils import log @@ -14,7 +14,8 @@ class HandlerTimeout(HandlerBase): - async def run(self, payload: Dict[str, Any], user: User) -> AsyncGenerator[Command, None]: + async def run(self, payload: Dict[str, Any], user: User) -> List[Command]: + commands = [] callback_id = user.message.callback_id if user.behaviors.has_callback(callback_id): params = {log_const.KEY_NAME: "handling_timeout"} @@ -31,5 +32,5 @@ async def run(self, payload: Dict[str, Any], user: User) -> AsyncGenerator[Comma user, app_info=app_info) callback_id = user.message.callback_id - async for command in user.behaviors.timeout(callback_id): - yield command + commands.extend(await user.behaviors.timeout(callback_id)) + return commands diff --git a/smart_kit/models/dialogue_manager.py b/smart_kit/models/dialogue_manager.py index 23769dc6..35fafc42 100644 --- a/smart_kit/models/dialogue_manager.py +++ b/smart_kit/models/dialogue_manager.py @@ -1,8 +1,6 @@ # coding: utf-8 from functools import cached_property -from typing import AsyncGenerator -from core.basic_models.actions.command import Command from core.logging.logger_utils import log from core.names import field @@ -30,12 +28,11 @@ def __init__(self, scenario_descriptions, app_name, **kwargs): def _nothing_found_action(self): return self.actions.get(self.NOTHING_FOUND_ACTION) or NothingFoundAction() - async def run(self, text_preprocessing_result, user) -> AsyncGenerator[Command, None]: + async def run(self, text_preprocessing_result, user): before_action = user.descriptions["external_actions"].get("before_action") if before_action: params = user.parametrizer.collect(text_preprocessing_result) - async for command in before_action.run(user, text_preprocessing_result, params): - yield command + await before_action.run(user, text_preprocessing_result, params) scenarios_names = user.last_scenarios.scenarios_names scenario_key = user.message.payload[field.INTENT] if scenario_key in scenarios_names: @@ -45,17 +42,13 @@ async def run(self, text_preprocessing_result, user) -> AsyncGenerator[Command, if not scenario.text_fits(text_preprocessing_result, user): params = user.parametrizer.collect(text_preprocessing_result) if scenario.check_ask_again_requests(text_preprocessing_result, user, params): - async for command in scenario.ask_again(text_preprocessing_result, user, params): - yield command - return + reply = await scenario.ask_again(text_preprocessing_result, user, params) + return reply, True monitoring.counter_nothing_found(self.app_name, scenario_key, user) - async for command in self._nothing_found_action.run(user, text_preprocessing_result): - yield command - return - async for command in self.run_scenario(scenario_key, text_preprocessing_result, user): - yield command + return await self._nothing_found_action.run(user, text_preprocessing_result), False + return await self.run_scenario(scenario_key, text_preprocessing_result, user), True - async def run_scenario(self, scen_id, text_preprocessing_result, user) -> AsyncGenerator[Command, None]: + async def run_scenario(self, scen_id, text_preprocessing_result, user): initial_last_scenario = user.last_scenarios.last_scenario_name scenario = self.scenarios[scen_id] params = {log_const.KEY_NAME: log_const.CHOSEN_SCENARIO_VALUE, @@ -63,9 +56,10 @@ async def run_scenario(self, scen_id, text_preprocessing_result, user) -> AsyncG log_const.SCENARIO_DESCRIPTION_VALUE: scenario.scenario_description } log(log_const.LAST_SCENARIO_MESSAGE, user, params) - async for command in scenario.run(text_preprocessing_result, user): - yield command + run_scenario_result = await scenario.run(text_preprocessing_result, user) actual_last_scenario = user.last_scenarios.last_scenario_name if actual_last_scenario and actual_last_scenario != initial_last_scenario: monitoring.counter_scenario_change(self.app_name, actual_last_scenario, user) + + return run_scenario_result diff --git a/smart_kit/models/smartapp_model.py b/smart_kit/models/smartapp_model.py index fb7a7924..609eb6dc 100644 --- a/smart_kit/models/smartapp_model.py +++ b/smart_kit/models/smartapp_model.py @@ -1,7 +1,7 @@ # coding: utf-8 import sys import traceback -from typing import AsyncGenerator +from typing import List, Optional from core.basic_models.actions.command import Command from core.descriptions.descriptions import Descriptions @@ -73,19 +73,20 @@ def init_additional_handlers(self): }) @exc_handler(on_error_obj_method_name="on_answer_error") - async def answer(self, message: SmartAppFromMessage, user: User) -> AsyncGenerator[Command, None]: + async def answer(self, message: SmartAppFromMessage, user: User) -> Optional[List[Command]]: user.expire() user.message_vars.clear() handler = self.get_handler(message.type) if not user.load_error: - async for command in handler.run(message.payload, user): - yield command + commands = await handler.run(message.payload, user) else: log("Error in loading user data", user, level="ERROR", exc_info=True) raise Exception("Error in loading user data") - async def on_answer_error(self, message, user) -> AsyncGenerator[Command, None]: + return commands + + async def on_answer_error(self, message, user): user.do_not_save = True monitoring.counter_exception(self.app_name) params = {log_const.KEY_NAME: log_const.DIALOG_ERROR_VALUE, @@ -101,6 +102,6 @@ async def on_answer_error(self, message, user) -> AsyncGenerator[Command, None]: if user.settings["template_settings"].get("debug_info"): set_debug_info(self.app_name, callback_action_params, error) exception_action = user.descriptions["external_actions"]["exception_action"] - async for command in exception_action.run(user=user, text_preprocessing_result=None, - params=callback_action_params): - yield command + commands = await exception_action.run(user=user, text_preprocessing_result=None, + params=callback_action_params) + return commands diff --git a/smart_kit/start_points/main_loop_kafka.py b/smart_kit/start_points/main_loop_kafka.py index 5a4dc349..6573d8d1 100644 --- a/smart_kit/start_points/main_loop_kafka.py +++ b/smart_kit/start_points/main_loop_kafka.py @@ -493,34 +493,21 @@ async def process_message(self, mq_message: KafkaMessage, consumer, kafka_key, s user=user, level="WARNING") user.local_vars.set(KAFKA_REPLY_TOPIC, message.headers[KAFKA_REPLY_TOPIC]) - publish_time_ms_sum = 0 with StatsTimer() as script_timer: - async for command in self.model.answer(message, user): - answers = self._generate_answers(user=user, commands=[command], message=message, - topic_key=topic_key, kafka_key=kafka_key) - if answers: - for answer in answers: - with StatsTimer() as publish_timer: - self._send_request(user, answer, mq_message) - publish_time_ms_sum += publish_timer.msecs - stats += "Publishing time: {} msecs\n".format(publish_timer.msecs) - - script_time_ms = script_timer.msecs - publish_time_ms_sum - script_time_sec = script_time_ms / 1000 - monitoring.sampling_script_time(self.app_name, script_time_sec) - stats += "Script time: {} msecs\n".format(script_time_ms) + commands = await self.model.answer(message, user) + + answers = self._generate_answers(user=user, commands=commands, message=message, + topic_key=topic_key, + kafka_key=kafka_key) + monitoring.sampling_script_time(self.app_name, script_timer.secs) + stats += "Script time: {} msecs\n".format(script_timer.msecs) with StatsTimer() as save_timer: user_save_no_collisions = await self.save_user(db_uid, user, message) monitoring.sampling_save_time(self.app_name, save_timer.secs) stats += "Saving time: {} msecs\n".format(save_timer.msecs) - - log(stats, user=user) - - if user_save_no_collisions: - self.save_behavior_timeouts(user, mq_message, kafka_key) - else: + if not user_save_no_collisions: log("MainLoop.iterate: save user got collision on uid %(uid)s db_version %(db_version)s.", user=user, params={log_const.KEY_NAME: "ignite_collision", @@ -532,6 +519,14 @@ async def process_message(self, mq_message: KafkaMessage, consumer, kafka_key, s "db_version": str(user.private_vars.get(user.USER_DB_VERSION))}, level="WARNING") continue + + if answers: + self.save_behavior_timeouts(user, mq_message, kafka_key) + for answer in answers: + with StatsTimer() as publish_timer: + self._send_request(user, answer, mq_message) + stats += "Publishing time: {} msecs\n".format(publish_timer.msecs) + log(stats, user=user) else: try: data = message.masked_value @@ -673,18 +668,15 @@ async def do_behavior_timeout(self, kwargs, worker_kwargs): if user.behaviors.has_callback(callback_id): callback_found = True + commands = await self.model.answer(timeout_from_message, user) topic_key = self._get_topic_key(mq_message, kafka_key) - async for command in self.model.answer(timeout_from_message, user): - answers = self._generate_answers(user=user, commands=[command], message=timeout_from_message, - topic_key=topic_key, kafka_key=kafka_key) - for answer in answers: - self._send_request(user, answer, mq_message) + answers = self._generate_answers(user=user, commands=commands, message=timeout_from_message, + topic_key=topic_key, + kafka_key=kafka_key) user_save_ok = await self.save_user(db_uid, user, mq_message) - if user_save_ok: - self.save_behavior_timeouts(user, mq_message, kafka_key) - else: + if not user_save_ok: log("MainLoop.do_behavior_timeout: save user got collision on uid %(uid)s db_version " "%(db_version)s.", user=user, @@ -713,6 +705,11 @@ async def do_behavior_timeout(self, kwargs, worker_kwargs): "db_version": str(user.private_vars.get(user.USER_DB_VERSION))}, level="WARNING") monitoring.counter_save_collision_tries_left(self.app_name) + + if user_save_ok: + self.save_behavior_timeouts(user, mq_message, kafka_key) + for answer in answers: + self._send_request(user, answer, mq_message) except Exception: log("%(class_name)s error.", params={log_const.KEY_NAME: "error_handling_timeout", "class_name": self.__class__.__name__, diff --git a/smart_kit/system_answers/nothing_found_action.py b/smart_kit/system_answers/nothing_found_action.py index 5252b586..906ed071 100644 --- a/smart_kit/system_answers/nothing_found_action.py +++ b/smart_kit/system_answers/nothing_found_action.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Optional, Union, AsyncGenerator +from typing import Dict, Any, Optional, List, Union from core.basic_models.actions.basic_actions import Action from core.basic_models.actions.string_actions import StringAction @@ -19,6 +19,7 @@ def __init__(self, items: Dict[str, Any] = None, id: Optional[str] = None): self._action = StringAction({"command": NOTHING_FOUND}) async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: - async for command in self._action.run(user, text_preprocessing_result, params=params): - yield command + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] + commands.extend(await self._action.run(user, text_preprocessing_result, params=params) or []) + return commands diff --git a/smart_kit/template/app/basic_entities/actions.py-tpl b/smart_kit/template/app/basic_entities/actions.py-tpl index 54a003ef..53c5e205 100644 --- a/smart_kit/template/app/basic_entities/actions.py-tpl +++ b/smart_kit/template/app/basic_entities/actions.py-tpl @@ -1,4 +1,4 @@ -from typing import Union, Dict, Any, Optional, AsyncGenerator +from typing import Union, Dict, Any, Optional, List from core.basic_models.actions.basic_actions import Action from core.text_preprocessing.preprocessing_result import TextPreprocessingResult @@ -18,7 +18,7 @@ class CustomAction(Action): self.test_param = items.get("test_param") async def run(self, user: User, text_preprocessing_result: TextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> AsyncGenerator[Command, None]: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + commands = [] # тело действия - return - yield # для поддержания интерфейса генератора + return commands diff --git a/smart_kit/template/app/handlers/handlers.py-tpl b/smart_kit/template/app/handlers/handlers.py-tpl index a90bedfd..94f75d8b 100644 --- a/smart_kit/template/app/handlers/handlers.py-tpl +++ b/smart_kit/template/app/handlers/handlers.py-tpl @@ -1,4 +1,4 @@ -from typing import Dict, Any, AsyncGenerator +from typing import List, Dict, Any from core.basic_models.actions.command import Command from smart_kit.handlers.handler_base import HandlerBase @@ -11,6 +11,6 @@ class CustomHandler(HandlerBase): Для сопоставления типа сообщения обработчику, добавьте такую пару в CustomModel.additional_handlers. По умолчанию, данный обработчик сопоставлен типу сообщения "CUSTOM_MESSAGE_NAME". """ - async def run(self, payload: Dict[str, Any], user: User) -> AsyncGenerator[Command, None]: - async for command in super().run(payload, user): - yield command + async def run(self, payload: Dict[str, Any], user: User) -> List[Command]: + commands = await super().run(payload, user) + return commands diff --git a/smart_kit/testing/suite.py b/smart_kit/testing/suite.py index 5746527d..f715c968 100644 --- a/smart_kit/testing/suite.py +++ b/smart_kit/testing/suite.py @@ -232,13 +232,11 @@ async def _run(self) -> bool: self.post_setup_user(user) - commands = [] - answers = [] - async for command in self.app_model.answer(message, user): - commands.append(command) - answers.extend(self._generate_answers( - user=user, commands=[command], message=message - )) + commands = await self.app_model.answer(message, user) or [] + + answers = self._generate_answers( + user=user, commands=commands, message=message + ) predefined_fields_resp = response.get("predefined_fields") if predefined_fields_resp: diff --git a/tests/core_tests/basic_scenario_models_test/action_test/test_action.py b/tests/core_tests/basic_scenario_models_test/action_test/test_action.py index dc0ee189..87e1edec 100644 --- a/tests/core_tests/basic_scenario_models_test/action_test/test_action.py +++ b/tests/core_tests/basic_scenario_models_test/action_test/test_action.py @@ -1,5 +1,4 @@ # coding: utf-8 -import inspect import json import unittest import uuid @@ -64,7 +63,7 @@ def __init__(self, items=None): self.result = items.get("result") async def run(self, user, text_preprocessing_result, params=None): - yield self.result or "test action run" + return self.result or ["test action run"] class UserMockAction: @@ -94,33 +93,6 @@ def collect(self, text_preprocessing_result, filter_params=None): return self.data -class AsyncIterator: - def __init__(self, *args, seq=None, **kwargs): - self.called = False - self.iter = iter(seq or list()) - for key, val in kwargs.items(): - self.__setattr__(key, val) - - def __aiter__(self): - self.called = True - return self - - async def __anext__(self): - try: - return next(self.iter) - except StopIteration: - raise StopAsyncIteration - - def __call__(self, *args, **kwargs): - return self - - def assert_called_once(self): - assert self.called - - def assert_not_called(self): - assert not self.called - - class ActionTest(unittest.IsolatedAsyncioTestCase): def test_nodes_1(self): items = {"nodes": {"answer": "test"}} @@ -140,8 +112,7 @@ async def test_base(self): items = {"nodes": "test"} action = Action(items) try: - async for command in action.run(None, None): - pass + await action.run(None, None) result = False except NotImplementedError: result = True @@ -152,22 +123,17 @@ async def test_external(self): action = ExternalAction(items) user = PicklableMock() user.descriptions = {"external_actions": {"test_action_key": MockAction()}} - result = [] - async for command in action.run(user, None): - result.append(command) - self.assertEqual(result, ["test action run"]) + self.assertEqual(await action.run(user, None), ["test action run"]) async def test_doing_nothing_action(self): items = {"nodes": {"answer": "test"}, "command": "test_name"} action = DoingNothingAction(items) - self.assertTrue(inspect.isasyncgen(action.run(None, None))) - result = [] - async for command in action.run(None, None): - result.append(command) + result = await action.run(None, None) + self.assertIsInstance(result, list) command = result[0] self.assertIsInstance(command, Command) - self.assertEqual("test_name", command.name) - self.assertEqual({"answer": "test"}, command.payload) + self.assertEqual(command.name, "test_name") + self.assertEqual(command.payload, {"answer": "test"}) async def test_requirement_action(self): requirements["test"] = MockRequirement @@ -176,15 +142,10 @@ async def test_requirement_action(self): action = RequirementAction(items) self.assertIsInstance(action.requirement, MockRequirement) self.assertIsInstance(action.internal_item, MockAction) - result = [] - async for command in action.run(None, None): - result.append(command) - self.assertEqual(result, ["test action run"]) + self.assertEqual(await action.run(None, None), ["test action run"]) items = {"requirement": {"type": "test", "result": False}, "action": {"type": "test"}} action = RequirementAction(items) - result = [] - async for command in action.run(None, None): - result.append(command) + result = await action.run(None, None) self.assertEqual(result, []) async def test_requirement_choice(self): @@ -195,10 +156,8 @@ async def test_requirement_choice(self): choice_action = ChoiceAction(items) self.assertIsInstance(choice_action.items, list) self.assertIsInstance(choice_action.items[0], RequirementAction) - result = [] - async for command in choice_action.run(None, None): - result.append(command) - self.assertEqual(result, [["action2"]]) + result = await choice_action.run(None, None) + self.assertEqual(result, ["action2"]) async def test_requirement_choice_else(self): items = { @@ -211,10 +170,8 @@ async def test_requirement_choice_else(self): choice_action = ChoiceAction(items) self.assertIsInstance(choice_action.items, list) self.assertIsInstance(choice_action.items[0], RequirementAction) - result = [] - async for command in choice_action.run(None, None): - result.append(command) - self.assertEqual(result, [["action3"]]) + result = await choice_action.run(None, None) + self.assertEqual(result, ["action3"]) async def test_string_action(self): expected = [Command("cmd_id", {"item": "template", "params": "params"})] @@ -226,9 +183,7 @@ async def test_string_action(self): user.parametrizer = MockSimpleParametrizer(user, {"data": params}) items = {"command": "cmd_id", "nodes": {"item": "template", "params": "{{params}}"}} action = StringAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertEqual(expected[0].name, result[0].name) self.assertEqual(expected[0].payload, result[0].payload) @@ -244,10 +199,7 @@ async def test_else_action_if(self): "else_action": {"type": "test", "result": ["else_action"]}, } action = ElseAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) - self.assertEqual(result, [["main_action"]]) + self.assertEqual(await action.run(user, None), ["main_action"]) async def test_else_action_else(self): registered_factories[Requirement] = requirement_factory @@ -261,10 +213,7 @@ async def test_else_action_else(self): "else_action": {"type": "test", "result": ["else_action"]}, } action = ElseAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) - self.assertEqual(result, [["else_action"]]) + self.assertEqual(await action.run(user, None), ["else_action"]) async def test_else_action_no_else_if(self): registered_factories[Requirement] = requirement_factory @@ -277,10 +226,7 @@ async def test_else_action_no_else_if(self): "action": {"type": "test", "result": ["main_action"]}, } action = ElseAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) - self.assertEqual(result, [["main_action"]]) + self.assertEqual(await action.run(user, None), ["main_action"]) async def test_else_action_no_else_else(self): registered_factories[Requirement] = requirement_factory @@ -293,10 +239,8 @@ async def test_else_action_no_else_else(self): "action": {"type": "test", "result": ["main_action"]}, } action = ElseAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) - self.assertEqual(result, []) + result = await action.run(user, None) + self.assertEqual([], result) async def test_composite_action(self): registered_factories[Action] = action_factory @@ -304,9 +248,7 @@ async def test_composite_action(self): user = PicklableMock() items = {"actions": [{"type": "action_mock"}, {"type": "action_mock"}]} action = CompositeAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertEqual(['test action run', 'test action run'], result) async def test_node_action_support_templates(self): @@ -331,10 +273,8 @@ async def test_node_action_support_templates(self): self.assertIsInstance(template, UnifiedTemplate) user = PicklableMagicMock() user.parametrizer = MockSimpleParametrizer(user, {"data": params}) - result = [] - async for command in action.run(user, None): - result.append(command) - self.assertEqual(expected, result[0].payload["answer"]) + output = (await action.run(user=user, text_preprocessing_result=None))[0].payload["answer"] + self.assertEqual(output, expected) async def test_string_action_support_templates(self): params = { @@ -355,10 +295,8 @@ async def test_string_action_support_templates(self): action = StringAction(items) user = PicklableMagicMock() user.parametrizer = MockSimpleParametrizer(user, {"data": params}) - result = [] - async for command in action.run(user, None): - result.append(command) - self.assertEqual(expected, result[0].payload) + output = (await action.run(user=user, text_preprocessing_result=None))[0].payload + self.assertEqual(output, expected) async def test_push_action(self): params = { @@ -400,10 +338,7 @@ async def test_push_action(self): action = PushAction(items) user.parametrizer = MockSimpleParametrizer(user, {"data": params}) user.settings = settings - result = [] - async for command in action.run(user, None): - result.append(command) - command = result[0] + command = (await action.run(user=user, text_preprocessing_result=None))[0] self.assertEqual(command.raw, expected) # проверяем наличие кастомных хэдеров для сервиса пушей self.assertTrue(SmartKitKafkaRequest.KAFKA_EXTRA_HEADERS in command.request_data) @@ -450,10 +385,8 @@ async def test_get_runtime_permissions(self): user.parametrizer = MockSimpleParametrizer(user, {"data": params}) user.settings = settings text_preprocessing_result = BaseTextPreprocessingResult(items) - result = [] - async for command in action.run(user=user, text_preprocessing_result=text_preprocessing_result): - result.append(command) - self.assertEqual(result[0].raw, expected) + command = (await action.run(user=user, text_preprocessing_result=text_preprocessing_result))[0] + self.assertEqual(command.raw, expected) def test_push_action_http_with_apprequest_lite_type_request(self): items = { @@ -615,7 +548,7 @@ async def test_push_authentication_action_http_call(self, request_mock: Mock): parametrizer=Mock(collect=lambda *args, **kwargs: {}), descriptions={ "behaviors": { - "common_behavior": Mock(timeout=Mock(return_value=4), success_action="") + "common_behavior": AsyncMock(timeout=Mock(return_value=4)) } } ) @@ -642,8 +575,7 @@ async def test_push_authentication_action_http_call(self, request_mock: Mock): "scope": "SMART_PUSH" } http_request_action.method_params["json"] = request_body_parameters - async for command in http_request_action.run(user, None, None): - pass + await http_request_action.run(user, None, None) request_mock.assert_called_with( url="https://salute.online.sberbank.ru:9443/api/v2/oauth", headers={ @@ -660,7 +592,7 @@ async def test_push_action_http_call_with_apprequest_lite_type_request(self, req parametrizer=Mock(collect=lambda *args, **kwargs: {}), descriptions={ "behaviors": { - "common_behavior": Mock(timeout=Mock(return_value=4), success_action="") + "common_behavior": AsyncMock(timeout=Mock(return_value=4)) } } ) @@ -717,8 +649,7 @@ async def test_push_action_http_call_with_apprequest_lite_type_request(self, req } } http_request_action.method_params["json"] = request_body_parameters - async for command in http_request_action.run(user, None, None): - pass + await http_request_action.run(user, None, None) request_mock.assert_called_with( url="https://salute.online.sberbank.ru:9443/api/v2/smartpush/apprequest-lite", headers={ @@ -735,7 +666,7 @@ async def test_push_action_http_call_with_apprequest_type_request(self, request_ parametrizer=Mock(collect=lambda *args, **kwargs: {}), descriptions={ "behaviors": { - "common_behavior": Mock(timeout=Mock(return_value=4), success_action="") + "common_behavior": AsyncMock(timeout=Mock(return_value=4)) } } ) @@ -902,8 +833,7 @@ async def test_push_action_http_call_with_apprequest_type_request(self, request_ } } http_request_action.method_params["json"] = request_body_parameters - async for command in http_request_action.run(user, None, None): - pass + await http_request_action.run(user, None, None) request_mock.assert_called_with( url="https://salute.online.sberbank.ru:9443/api/v2/smartpush/apprequest", headers={ @@ -933,16 +863,14 @@ def setUp(self): async def test_run_available_indexes(self): self.user.last_action_ids["last_action_ids_storage"].get_list.side_effect = [[0]] - async for command in self.action.run(self.user, None): - self.assertEqual(self.expected1, command) + result = await self.action.run(self.user, None) self.user.last_action_ids["last_action_ids_storage"].add.assert_called_once() + self.assertEqual(result, self.expected1) async def test_run_no_available_indexes(self): self.user.last_action_ids["last_action_ids_storage"].get_list.side_effect = [[0, 1]] - result = [] - async for command in self.action.run(self.user, None): - result.append(command) - self.assertEqual(result, [self.expected]) + result = await self.action.run(self.user, None) + self.assertEqual(result, self.expected) class CounterIncrementActionTest(unittest.IsolatedAsyncioTestCase): @@ -953,9 +881,7 @@ async def test_run(self): user.counters = {"test": counter} items = {"key": "test"} action = CounterIncrementAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) + await action.run(user, None) user.counters["test"].inc.assert_called_once() @@ -967,9 +893,7 @@ async def test_run(self): user.counters = {"test": counter} items = {"key": "test"} action = CounterDecrementAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) + await action.run(user, None) user.counters["test"].dec.assert_called_once() @@ -980,9 +904,7 @@ async def test_run(self): user.counters.inc = PicklableMock() items = {"key": "test"} action = CounterClearAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) + await action.run(user, None) user.counters.clear.assert_called_once() @@ -995,9 +917,7 @@ async def test_run(self): user.counters = counters items = {"key": "test"} action = CounterSetAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) + await action.run(user, None) user.counters["test"].set.assert_called_once() @@ -1010,9 +930,7 @@ async def test_run(self): user.counters = {"src": counter_src, "dst": counter_dst} items = {"source": "src", "destination": "dst"} action = CounterCopyAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) + await action.run(user, None) user.counters["dst"].set.assert_called_once_with(user.counters["src"].value, action.reset_time, action.time_shift) @@ -1030,9 +948,7 @@ async def test_typical_answer(self): } } action = AfinaAnswerAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertEqual(expected[0]._name, result[0].name) self.assertEqual(expected[0].raw, result[0].raw) @@ -1052,9 +968,7 @@ async def test_typical_answer_with_other(self): "nodes": {"answer": ["a1", "a1", "a1"], "pronounce_text": ["pt2"], "picture": ["1.jpg", "1.jpg", "1.jpg"]} } action = AfinaAnswerAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertEqual(expected[0]._name, result[0].name) self.assertEqual(expected[0].raw, result[0].raw) @@ -1067,9 +981,7 @@ async def test_typical_answer_with_pers_info(self): user.message.payload = {"personInfo": {"name": "Ivan Ivanov"}} items = {"nodes": {"answer": ["{{payload.personInfo.name}}"]}} action = AfinaAnswerAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertEqual(expected[0]._name, result[0].name) self.assertEqual(expected[0].raw, result[0].raw) @@ -1081,9 +993,7 @@ async def test_items_empty(self): user.descriptions = {"render_templates": template} items = None action = AfinaAnswerAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertEqual(result, []) async def test__items_empty_dict(self): @@ -1094,9 +1004,7 @@ async def test__items_empty_dict(self): user.descriptions = {"render_templates": template} items = {} action = AfinaAnswerAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertEqual(result, []) @@ -1250,9 +1158,7 @@ async def test_typical_answer(self): for i in range(10): action = SDKAnswer(items) - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertEqual("ANSWER_TO_USER", result[0].name) self.assertTrue(json.dumps(result[0].raw, sort_keys=True) in expect_arr) self.assertFalse(json.dumps(result[0].raw, sort_keys=True) in not_expect_arr) @@ -1277,9 +1183,7 @@ async def test_typical_answer_without_items(self): exp_list = [exp1, exp2] for i in range(10): action = SDKAnswer(items) - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertEqual("ANSWER_TO_USER", result[0].name) self.assertTrue(json.dumps((result[0].raw), sort_keys=True) in exp_list) @@ -1366,9 +1270,7 @@ async def test_typical_answer_without_nodes(self): not_expect_arr = [nexp1, nexp2] for i in range(10): action = SDKAnswer(items) - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertEqual("ANSWER_TO_USER", result[0].name) self.assertTrue(json.dumps((result[0].raw), sort_keys=True) in expect_arr) self.assertFalse(json.dumps((result[0].raw), sort_keys=True) in not_expect_arr) @@ -1468,9 +1370,7 @@ async def test_SDKItemAnswer_full(self): action = SDKAnswerToUser(items) for i in range(3): - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertTrue(json.dumps(result[0].raw, sort_keys=True) in [exp1, exp2]) async def test_SDKItemAnswer_root(self): @@ -1505,9 +1405,7 @@ async def test_SDKItemAnswer_root(self): action = SDKAnswerToUser(items) for i in range(3): - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertTrue(json.dumps(result[0].raw, sort_keys=True) in [exp1, exp2]) async def test_SDKItemAnswer_simple(self): @@ -1519,9 +1417,7 @@ async def test_SDKItemAnswer_simple(self): user.parametrizer = MockParametrizer(user, {}) items = {"type": "sdk_answer_to_user", "items": [{"type": "bubble_text", "text": "42"}]} action = SDKAnswerToUser(items) - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertDictEqual( result[0].raw, {'messageName': 'ANSWER_TO_USER', 'payload': {'items': [{'bubble': {'text': '42', 'markdown': True}}]}} @@ -1547,9 +1443,7 @@ async def test_SDKItemAnswer_suggestions_template(self): }, } action = SDKAnswerToUser(items) - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertDictEqual( result[0].raw, { @@ -1602,9 +1496,7 @@ async def test_run(self, settings_mock: MagicMock): } text_preprocessing_result = PicklableMock() action = GiveMeMemoryAction(items) - result = [] - async for command in action.run(user, text_preprocessing_result): - result.append(command) + result = await action.run(user, text_preprocessing_result) self.assertEqual(expected[0].name, result[0].name) self.assertEqual(expected[0].payload, result[0].payload) @@ -1692,8 +1584,6 @@ async def test_run(self): } } action = RememberThisAction(items) - result = [] - async for command in action.run(user, None): - result.append(command) + result = await action.run(user, None) self.assertEqual(expected[0].name, result[0].name) self.assertEqual(expected[0].payload, result[0].payload) diff --git a/tests/core_tests/basic_scenario_models_test/action_test/test_random_action.py b/tests/core_tests/basic_scenario_models_test/action_test/test_random_action.py index f9bac443..2449c9a0 100644 --- a/tests/core_tests/basic_scenario_models_test/action_test/test_random_action.py +++ b/tests/core_tests/basic_scenario_models_test/action_test/test_random_action.py @@ -31,9 +31,7 @@ async def test_1(self): ] } action = RandomAction(items, 5) - result = [] - async for command in action.run(None, None): - result.append(command) + result = await action.run(None, None) self.assertIsNotNone(result) async def test_2(self): @@ -49,7 +47,5 @@ async def test_2(self): ] } action = RandomAction(items, 5) - result = [] - async for command in action.run(None, None): - result.append(command) + result = await action.run(None, None) self.assertIsNotNone(result) diff --git a/tests/core_tests/basic_scenario_models_test/action_test/test_smartpay.py b/tests/core_tests/basic_scenario_models_test/action_test/test_smartpay.py index 7668a767..8174b572 100644 --- a/tests/core_tests/basic_scenario_models_test/action_test/test_smartpay.py +++ b/tests/core_tests/basic_scenario_models_test/action_test/test_smartpay.py @@ -1,12 +1,11 @@ import unittest -from unittest.mock import patch, MagicMock, AsyncMock, Mock +from unittest.mock import patch, MagicMock, AsyncMock from aiohttp import ClientTimeout from core.basic_models.actions.smartpay import SmartPayCreateAction, SmartPayPerformAction, SmartPayGetStatusAction, \ SmartPayConfirmAction, SmartPayDeleteAction, SmartPayRefundAction from smart_kit.utils.picklable_mock import PicklableMock, AsyncPicklableMock -from tests.core_tests.basic_scenario_models_test.action_test.test_action import AsyncIterator from tests.smart_kit_tests.action.test_http import HttpRequestActionTest @@ -16,8 +15,7 @@ def setUp(self): parametrizer=PicklableMock(collect=lambda *args, **kwargs: {}), descriptions={ "behaviors": { - "my_behavior": Mock(timeout=PicklableMock(return_value=3), - success_action=PicklableMock(run=AsyncIterator([]))) + "my_behavior": AsyncMock(timeout=PicklableMock(return_value=3)) } } ) @@ -147,8 +145,7 @@ async def test_create(self, request_mock: PicklableMock, settings_mock: MagicMoc settings_mock.return_value = self.settings HttpRequestActionTest.set_request_mock_attribute(request_mock, return_value={'data': 'value'}) action = SmartPayCreateAction(items) - async for command in action.run(self.user, None, {}): - pass + await action.run(self.user, None, {}) request_mock.assert_called_with(url="0.0.0.0/invoices", method="POST", timeout=ClientTimeout(3), json=items.get("smartpay_params")) self.assertTrue(self.user.descriptions["behaviors"]["my_behavior"].success_action.run.called) @@ -189,8 +186,7 @@ async def test_perform(self, request_mock: PicklableMock, settings_mock: MagicMo settings_mock.return_value = self.settings HttpRequestActionTest.set_request_mock_attribute(request_mock, return_value={'data': 'value'}) action = SmartPayPerformAction(items) - async for command in action.run(self.user, None, {}): - pass + await action.run(self.user, None, {}) request_mock.assert_called_with(url="0.0.0.0/invoices/0", method="POST", timeout=ClientTimeout(3), json=items.get("smartpay_params")) self.assertTrue(self.user.descriptions["behaviors"]["my_behavior"].success_action.run.called) @@ -210,8 +206,7 @@ async def test_get_status(self, request_mock: PicklableMock, settings_mock: Magi settings_mock.return_value = self.settings HttpRequestActionTest.set_request_mock_attribute(request_mock, return_value={'data': 'value'}) action = SmartPayGetStatusAction(items) - async for command in action.run(self.user, None, {}): - pass + await action.run(self.user, None, {}) request_mock.assert_called_with(url="0.0.0.0/invoices/0", method="GET", timeout=ClientTimeout(3), params={"inv_status": "executed", "wait": 50}) self.assertTrue(self.user.descriptions["behaviors"]["my_behavior"].success_action.run.called) @@ -251,8 +246,7 @@ async def test_partial_confirm(self, request_mock: PicklableMock, settings_mock: settings_mock.return_value = self.settings HttpRequestActionTest.set_request_mock_attribute(request_mock, return_value={'data': 'value'}) action = SmartPayConfirmAction(items) - async for command in action.run(self.user, None, {}): - pass + await action.run(self.user, None, {}) request_mock.assert_called_with(url="0.0.0.0/invoices/0", method="PUT", timeout=ClientTimeout(3), json=items["smartpay_params"]) self.assertTrue(self.user.descriptions["behaviors"]["my_behavior"].success_action.run.called) @@ -270,8 +264,7 @@ async def test_full_confirm(self, request_mock: PicklableMock, settings_mock: Ma settings_mock.return_value = self.settings HttpRequestActionTest.set_request_mock_attribute(request_mock, return_value={'data': 'value'}) action = SmartPayConfirmAction(items) - async for command in action.run(self.user, None, {}): - pass + await action.run(self.user, None, {}) request_mock.assert_called_with(url="0.0.0.0/invoices/0", method="PUT", timeout=ClientTimeout(3)) self.assertTrue(self.user.descriptions["behaviors"]["my_behavior"].success_action.run.called) self.assertTrue(self.user.variables.set.called) @@ -288,8 +281,7 @@ async def test_delete(self, request_mock: PicklableMock, settings_mock: MagicMoc settings_mock.return_value = self.settings HttpRequestActionTest.set_request_mock_attribute(request_mock, return_value={'data': 'value'}) action = SmartPayDeleteAction(items) - async for command in action.run(self.user, None, {}): - pass + await action.run(self.user, None, {}) request_mock.assert_called_with(url="0.0.0.0/invoices/0", method="DELETE", timeout=ClientTimeout(3)) self.assertTrue(self.user.descriptions["behaviors"]["my_behavior"].success_action.run.called) self.assertTrue(self.user.variables.set.called) @@ -326,8 +318,7 @@ async def test_partial_refund(self, request_mock: PicklableMock, settings_mock: settings_mock.return_value = self.settings HttpRequestActionTest.set_request_mock_attribute(request_mock, return_value={'data': 'value'}) action = SmartPayRefundAction(items) - async for command in action.run(self.user, None, {}): - pass + await action.run(self.user, None, {}) request_mock.assert_called_with(url="0.0.0.0/invoices/0", method="PATCH", timeout=ClientTimeout(3), json=items["smartpay_params"]) self.assertTrue(self.user.descriptions["behaviors"]["my_behavior"].success_action.run.called) @@ -345,8 +336,7 @@ async def test_full_refund(self, request_mock: PicklableMock, settings_mock: Mag settings_mock.return_value = self.settings HttpRequestActionTest.set_request_mock_attribute(request_mock, return_value={'data': 'value'}) action = SmartPayRefundAction(items) - async for command in action.run(self.user, None, {}): - pass + await action.run(self.user, None, {}) request_mock.assert_called_with(url="0.0.0.0/invoices/0", method="PATCH", timeout=ClientTimeout(3)) self.assertTrue(self.user.descriptions["behaviors"]["my_behavior"].success_action.run.called) self.assertTrue(self.user.variables.set.called) diff --git a/tests/scenarios_tests/actions_test/test_action.py b/tests/scenarios_tests/actions_test/test_action.py index 7cf889ca..74b4a60e 100644 --- a/tests/scenarios_tests/actions_test/test_action.py +++ b/tests/scenarios_tests/actions_test/test_action.py @@ -25,7 +25,6 @@ from smart_kit.action.smart_geo_action import SmartGeoAction from smart_kit.message.smartapp_to_message import SmartAppToMessage from smart_kit.utils.picklable_mock import PicklableMock, PicklableMagicMock, AsyncPicklableMock -from tests.core_tests.basic_scenario_models_test.action_test.test_action import AsyncIterator from tests.core_tests.requirements_test.test_requirements import MockRequirement @@ -35,8 +34,7 @@ def __init__(self, items=None): self.result = items.get("result") async def run(self, user, text_preprocessing_result, params=None): - for command in self.result or ["test action run"]: - yield command + return self.result or ["test action run"] class MockParametrizer: @@ -60,8 +58,7 @@ class ClearFormIdActionTest(unittest.IsolatedAsyncioTestCase): async def test_run(self): action = ClearFormAction({"form": "form"}) user = PicklableMagicMock() - async for command in action.run(user, None): - pass + await action.run(user, None) user.forms.remove_item.assert_called_once_with("form") @@ -70,8 +67,7 @@ async def test_run(self): action = ClearInnerFormAction({"form": "form", "inner_form": "inner_form"}) user, form = PicklableMagicMock(), PicklableMagicMock() user.forms.__getitem__.return_value = form - async for command in action.run(user, None): - pass + await action.run(user, None) form.forms.remove_item.assert_called_once_with("inner_form") @@ -83,8 +79,7 @@ async def test_run_1(self): scenario_model = PicklableMagicMock() scenario_model.set_break = Mock(return_value=None) user.scenario_models = {scenario_id: scenario_model} - async for command in action.run(user, None): - pass + await action.run(user, None) user.scenario_models[scenario_id].set_break.assert_called_once() async def test_run_2(self): @@ -95,8 +90,7 @@ async def test_run_2(self): scenario_model = PicklableMagicMock() scenario_model.set_break = Mock(return_value=None) user.scenario_models = {scenario_id: scenario_model} - async for command in action.run(user, None): - pass + await action.run(user, None) user.scenario_models[scenario_id].set_break.assert_called_once() @@ -105,8 +99,7 @@ async def test_run(self): action = RemoveFormFieldAction({"form": "form", "field": "field"}) user, form = PicklableMagicMock(), PicklableMagicMock() user.forms.__getitem__.return_value = form - async for command in action.run(user, None): - pass + await action.run(user, None) form.fields.remove_item.assert_called_once_with("field") @@ -116,8 +109,7 @@ async def test_run(self): user, inner_form, form = PicklableMagicMock(), PicklableMagicMock(), PicklableMagicMock() form.forms.__getitem__.return_value = inner_form user.forms.__getitem__.return_value = form - async for command in action.run(user, None): - pass + await action.run(user, None) inner_form.fields.remove_item.assert_called_once_with("field") @@ -140,8 +132,7 @@ async def test_save_behavior_scenario_name(self): action = SaveBehaviorAction(data) tpr = PicklableMock() tpr_raw = tpr.raw - async for command in action.run(self.user, tpr): - pass + await action.run(self.user, tpr) self.user.behaviors.add.assert_called_once_with(self.user.message.generate_new_callback_id(), "test", self.user.last_scenarios.last_scenario_name, tpr_raw, action_params=None) @@ -154,8 +145,7 @@ async def test_save_behavior_without_scenario_name(self): action = SaveBehaviorAction(data) text_preprocessing_result_raw = PicklableMock() text_preprocessing_result = Mock(raw=text_preprocessing_result_raw) - async for command in action.run(self.user, text_preprocessing_result, None): - pass + await action.run(self.user, text_preprocessing_result, None) self.user.behaviors.add.assert_called_once_with(self.user.message.generate_new_callback_id(), "test", None, text_preprocessing_result_raw, action_params=None) @@ -183,9 +173,7 @@ async def test_action_1(self): action = SelfServiceActionWithState(data) text_preprocessing_result_raw = PicklableMock() text_preprocessing_result = Mock(raw=text_preprocessing_result_raw) - result = [] - async for command in action.run(self.user, text_preprocessing_result, None): - result.append(command) + result = await action.run(self.user, text_preprocessing_result, None) behavior.check_got_saved_id.assert_called_once() behavior.add.assert_called_once() self.assertEqual(result[0].name, "cmd_id") @@ -201,9 +189,7 @@ async def test_action_2(self): self.user.behaviors = behavior behavior.check_got_saved_id = Mock(return_value=True) action = SelfServiceActionWithState(data) - result = [] - async for command in action.run(self.user, None): - result.append(command) + result = await action.run(self.user, None) behavior.add.assert_not_called() self.assertEqual(result, []) @@ -232,9 +218,7 @@ async def test_action_3(self): action = SelfServiceActionWithState(data) text_preprocessing_result_raw = PicklableMock() text_preprocessing_result = Mock(raw=text_preprocessing_result_raw) - result = [] - async for command in action.run(self.user, text_preprocessing_result, None): - result.append(command) + result = await action.run(self.user, text_preprocessing_result, None) behavior.check_got_saved_id.assert_called_once() behavior.add.assert_called_once() self.assertEqual(result[0].name, "cmd_id") @@ -262,22 +246,19 @@ def setUp(self): async def test_action(self): action = SetVariableAction({"key": "some_key", "value": "some_value"}) - async for command in action.run(self.user, None): - pass + await action.run(self.user, None) self.user.variables.set.assert_called_with("some_key", "some_value", None) async def test_action_jinja_key_default(self): self.user.message.payload = {"some_value": "some_value_test"} action = SetVariableAction({"key": "some_key", "value": "{{payload.some_value}}"}) - async for command in action.run(self.user, None): - pass + await action.run(self.user, None) self.user.variables.set.assert_called_with("some_key", "some_value_test", None) async def test_action_jinja_no_key(self): self.user.message.payload = {"some_value": "some_value_test"} action = SetVariableAction({"key": "some_key", "value": "{{payload.no_key}}"}) - async for command in action.run(self.user, None): - pass + await action.run(self.user, None) self.user.variables.set.assert_called_with("some_key", "", None) @@ -295,8 +276,7 @@ def setUp(self): async def test_action(self): action = DeleteVariableAction({"key": "some_key_1"}) - async for command in action.run(self.user, None): - pass + await action.run(self.user, None) self.user.variables.delete.assert_called_with("some_key_1") @@ -318,8 +298,7 @@ def setUp(self): async def test_action(self): action = ClearVariablesAction() - async for command in action.run(self.user, None): - pass + await action.run(self.user, None) self.user.variables.clear.assert_called_with() @@ -335,8 +314,7 @@ async def test_fill_field(self): field = PicklableMock() field.fill = PicklableMock() user.forms["test_form"].fields = {"test_field": field} - async for command in action.run(user, None): - pass + await action.run(user, None) field.fill.assert_called_once_with(params["test_field"]) @@ -356,8 +334,7 @@ async def test_fill_field(self): field = PicklableMock() field.fill = PicklableMock() user.forms["test_form"].forms["test_internal_form"].fields = {"test_field": field} - async for command in action.run(user, None): - pass + await action.run(user, None) field.fill.assert_called_once_with(params["test_field"]) @@ -368,11 +345,9 @@ async def test_scenario_action(self): user.parametrizer = MockParametrizer(user, {}) scen = AsyncPicklableMock() scen_result = ['done'] - scen.run = AsyncIterator(seq=scen_result) + scen.run.return_value = scen_result user.descriptions = {"scenarios": {"test": scen}} - result = [] - async for command in action.run(user, PicklableMock()): - result.append(command) + result = await action.run(user, PicklableMock()) self.assertEqual(result, scen_result) async def test_scenario_action_with_jinja_good(self): @@ -384,11 +359,9 @@ async def test_scenario_action_with_jinja_good(self): user.parametrizer = MockParametrizer(user, {"data": params}) scen = AsyncPicklableMock() scen_result = ['done'] - scen.run = AsyncIterator(seq=scen_result) + scen.run.return_value = scen_result user.descriptions = {"scenarios": {"ANNA.pipeline.scenario": scen}} - result = [] - async for command in action.run(user, PicklableMock()): - result.append(command) + result = await action.run(user, PicklableMock()) self.assertEqual(result, scen_result) async def test_scenario_action_no_scenario(self): @@ -399,9 +372,7 @@ async def test_scenario_action_no_scenario(self): scen_result = 'done' scen.run.return_value = scen_result user.descriptions = {"scenarios": {"next_scenario": scen}} - result = [] - async for command in action.run(user, PicklableMock()): - result.append(command) + result = await action.run(user, PicklableMock()) self.assertEqual(result, []) async def test_scenario_action_without_jinja(self): @@ -410,11 +381,9 @@ async def test_scenario_action_without_jinja(self): user.parametrizer = MockParametrizer(user, {}) scen = AsyncPicklableMock() scen_result = ['done'] - scen.run = AsyncIterator(seq=scen_result) + scen.run.return_value = scen_result user.descriptions = {"scenarios": {"next_scenario": scen}} - result = [] - async for command in action.run(user, PicklableMock()): - result.append(command) + result = await action.run(user, PicklableMock()) self.assertEqual(result, scen_result) @@ -425,15 +394,13 @@ async def test_scenario_action(self): user = PicklableMock() scen = AsyncPicklableMock() scen_result = ['done'] - scen.run = AsyncIterator(seq=scen_result) + scen.run.return_value = scen_result user.descriptions = {"scenarios": {"test": scen}} user.last_scenarios = PicklableMock() last_scenario_name = "test" user.last_scenarios.scenarios_names = [last_scenario_name] user.last_scenarios.last_scenario_name = last_scenario_name - result = [] - async for command in action.run(user, PicklableMock()): - result.append(command) + result = await action.run(user, PicklableMock()) self.assertEqual(result, scen_result) @@ -456,13 +423,10 @@ async def mock_and_perform_action(test_items: Dict[str, Any], expected_result: O user = PicklableMock() user.parametrizer = MockParametrizer(user, {}) scen = AsyncPicklableMock() - scen.run = AsyncIterator(seq=expected_result) + scen.run.return_value = expected_result if expected_scen: user.descriptions = {"scenarios": {expected_scen: scen}} - result = [] - async for command in action.run(user, PicklableMock()): - result.append(command) - return result + return await action.run(user, PicklableMock()) async def test_choice_scenario_action(self): # Проверяем, что запустили нужный сценарий, в случае если выполнился его requirement @@ -540,9 +504,7 @@ async def test_action(self): user.descriptions = {"scenarios": {scenario_name: scenario}} action = ClearCurrentScenarioAction({}) - result = [] - async for command in action.run(user, {}, {}): - result.append(command) + result = await action.run(user, {}, {}) self.assertEqual([], result) user.last_scenarios.delete.assert_called_once() user.forms.remove_item.assert_called_once() @@ -555,9 +517,7 @@ async def test_action_with_empty_scenarios_names(self): user.last_scenarios.delete = PicklableMock() action = ClearCurrentScenarioAction({}) - result = [] - async for command in action.run(user, {}, {}): - result.append(command) + result = await action.run(user, {}, {}) self.assertEqual([], result) user.last_scenarios.delete.assert_not_called() user.forms.remove_item.assert_not_called() @@ -576,9 +536,7 @@ async def test_action(self): user.descriptions = {"scenarios": {scenario_name: scenario}} action = ClearScenarioByIdAction({"scenario_id": scenario_name}) - result = [] - async for command in action.run(user, {}, {}): - result.append(command) + result = await action.run(user, {}, {}) self.assertEqual([], result) user.last_scenarios.delete.assert_called_once() user.forms.remove_item.assert_called_once() @@ -590,9 +548,7 @@ async def test_action_with_empty_scenarios_names(self): user.last_scenarios.last_scenario_name = "test_scenario" action = ClearScenarioByIdAction({}) - result = [] - async for command in action.run(user, {}, {}): - result.append(command) + result = await action.run(user, {}, {}) self.assertEqual([], result) user.last_scenarios.delete.assert_not_called() user.forms.remove_item.assert_not_called() @@ -612,9 +568,7 @@ async def test_action(self): user.descriptions = {"scenarios": {scenario_name: scenario}} action = ClearCurrentScenarioFormAction({}) - result = [] - async for command in action.run(user, {}, {}): - result.append(command) + result = await action.run(user, {}, {}) self.assertEqual([], result) user.forms.clear_form.assert_called_once() @@ -631,9 +585,7 @@ async def test_action_with_empty_last_scenario(self): user.descriptions = {"scenarios": {scenario_name: scenario}} action = ClearCurrentScenarioFormAction({}) - result = [] - async for command in action.run(user, {}, {}): - result.append(command) + result = await action.run(user, {}, {}) self.assertEqual([], result) user.forms.remove_item.assert_not_called() @@ -648,9 +600,7 @@ async def test_action(self): user.scenario_models = {'test_scenario': scenario_model} action = ResetCurrentNodeAction({}) - result = [] - async for command in action.run(user, {}, {}): - result.append(command) + result = await action.run(user, {}, {}) self.assertEqual([], result) self.assertIsNone(user.scenario_models['test_scenario'].current_node) @@ -663,9 +613,7 @@ async def test_action_with_empty_last_scenario(self): user.scenario_models = {'test_scenario': scenario_model} action = ResetCurrentNodeAction({}) - result = [] - async for command in action.run(user, {}, {}): - result.append(command) + result = await action.run(user, {}, {}) self.assertEqual([], result) self.assertEqual('some_node', user.scenario_models['test_scenario'].current_node) @@ -681,9 +629,7 @@ async def test_specific_target(self): 'node_id': 'another_node' } action = ResetCurrentNodeAction(items) - result = [] - async for command in action.run(user, {}, {}): - result.append(command) + result = await action.run(user, {}, {}) self.assertEqual([], result) self.assertEqual('another_node', user.scenario_models['test_scenario'].current_node) @@ -722,8 +668,7 @@ async def test_action_with_non_empty_scenario(self): ) action = AddHistoryEventAction(items) - async for command in action.run(self.user, None, None): - pass + await action.run(self.user, None, None) self.user.history.add_event.assert_called_once() self.user.history.add_event.assert_called_once_with(expected) @@ -737,8 +682,7 @@ async def test_action_with_empty_scenario(self): } action = AddHistoryEventAction(items) - async for command in action.run(self.user, None, None): - pass + await action.run(self.user, None, None) self.user.history.add_event.assert_not_called() @@ -761,8 +705,7 @@ async def test_action_with_jinja(self): ) action = AddHistoryEventAction(items) - async for command in action.run(self.user, None, None): - pass + await action.run(self.user, None, None) self.user.history.add_event.assert_called_once() self.user.history.add_event.assert_called_once_with(expected) @@ -783,10 +726,7 @@ async def test_action_send_request(self): user = Mock() text_preprocessing_result = Mock() params = Mock() - result = [] - async for command in self.smart_geo_action.run(user, text_preprocessing_result, params): - result.append(command) - command = result[0] + command = (await self.smart_geo_action.run(user, text_preprocessing_result, params))[0] answer = SmartAppToMessage(command, incoming_message, None) expected = { "messageId": "1605196199186625000", diff --git a/tests/scenarios_tests/behaviors_test/test_behavior_model.py b/tests/scenarios_tests/behaviors_test/test_behavior_model.py index cf40f06e..d331f673 100644 --- a/tests/scenarios_tests/behaviors_test/test_behavior_model.py +++ b/tests/scenarios_tests/behaviors_test/test_behavior_model.py @@ -7,7 +7,6 @@ import scenarios.behaviors.behaviors from smart_kit.utils.picklable_mock import PicklableMock, AsyncPicklableMock -from tests.core_tests.basic_scenario_models_test.action_test.test_action import AsyncIterator class BehaviorsTest(unittest.IsolatedAsyncioTestCase): @@ -19,11 +18,9 @@ def setUp(self): self.description = PicklableMock() self.description.timeout = Mock(return_value=10) self.success_action = AsyncPicklableMock() - self.success_action.run = AsyncIterator() + self.success_action.run = AsyncPicklableMock() self.fail_action = AsyncPicklableMock() - self.fail_action.run = AsyncIterator() self.timeout_action = AsyncPicklableMock() - self.timeout_action.run = AsyncIterator() self.description.success_action = self.success_action self.description.fail_action = self.fail_action @@ -39,8 +36,7 @@ async def test_success(self): items = {str(callback_id): item} behaviors = scenarios.behaviors.behaviors.Behaviors(items, self.descriptions, self.user) behaviors.initialize() - async for command in behaviors.success(callback_id): - pass + await behaviors.success(callback_id) # self.success_action.run.assert_called_once_with(self.user, TextPreprocessingResult({})) self.success_action.run.assert_called_once() self.assertDictEqual(behaviors.raw, {}) @@ -50,8 +46,7 @@ async def test_success_2(self): items = {} behaviors = scenarios.behaviors.behaviors.Behaviors(items, self.descriptions, self.user) behaviors.initialize() - async for command in behaviors.success(callback_id): - pass + await behaviors.success(callback_id) self.success_action.run.assert_not_called() async def test_fail(self): @@ -61,8 +56,7 @@ async def test_fail(self): items = {str(callback_id): item} behaviors = scenarios.behaviors.behaviors.Behaviors(items, self.descriptions, self.user) behaviors.initialize() - async for command in behaviors.fail(callback_id): - pass + await behaviors.fail(callback_id) self.fail_action.run.assert_called_once() self.assertDictEqual(behaviors.raw, {}) @@ -73,8 +67,7 @@ async def test_timeout(self): items = {str(callback_id): item} behaviors = scenarios.behaviors.behaviors.Behaviors(items, self.descriptions, self.user) behaviors.initialize() - async for command in behaviors.timeout(callback_id): - pass + await behaviors.timeout(callback_id) self.timeout_action.run.assert_called_once() self.assertDictEqual(behaviors.raw, {}) diff --git a/tests/scenarios_tests/scenarios_test/test_tree_scenario.py b/tests/scenarios_tests/scenarios_test/test_tree_scenario.py index 4dca138f..e1b7dec1 100644 --- a/tests/scenarios_tests/scenarios_test/test_tree_scenario.py +++ b/tests/scenarios_tests/scenarios_test/test_tree_scenario.py @@ -86,9 +86,7 @@ async def test_1(self): scenario = TreeScenario(items, 1) - return # TODO remove after adapting tree scenario for async generator interfaces of actions - async for command in scenario.run(text_preprocessing_result, user): - pass + _ = await scenario.run(text_preprocessing_result, user) self.assertIsNone(current_node_mock.current_node) context_forms.new.assert_called_once_with(form_type) @@ -151,10 +149,7 @@ async def test_break(self): scenario = TreeScenario(items, 1) - return # TODO remove after adapting tree scenario for async generator interfaces of actions - result = [] - async for command in scenario.run(text_preprocessing_result, user): - result.append(command) + result = await scenario.run(text_preprocessing_result, user) self.assertFalse(scenario.actions[0].called) self.assertEqual(result[0].name, "break action result") diff --git a/tests/smart_kit_tests/action/test_http.py b/tests/smart_kit_tests/action/test_http.py index ae4fd888..12078c89 100644 --- a/tests/smart_kit_tests/action/test_http.py +++ b/tests/smart_kit_tests/action/test_http.py @@ -4,7 +4,6 @@ from aiohttp import ClientTimeout from smart_kit.action.http import HTTPRequestAction -from tests.core_tests.basic_scenario_models_test.action_test.test_action import AsyncIterator class HttpRequestActionTest(unittest.IsolatedAsyncioTestCase): @@ -15,8 +14,7 @@ def setUp(self): parametrizer=Mock(collect=lambda *args, **kwargs: {}), descriptions={ "behaviors": { - "my_behavior": Mock(timeout=Mock(return_value=3), - success_action=Mock(run=AsyncIterator())) + "my_behavior": AsyncMock(timeout=Mock(return_value=3)) } } ) @@ -45,8 +43,7 @@ async def test_simple_request(self, request_mock: Mock): "store": "user_variable", "behavior": "my_behavior", } - async for command in HTTPRequestAction(items).run(self.user, None, {}): - pass + await HTTPRequestAction(items).run(self.user, None, {}) request_mock.assert_called_with(url="https://my.url.com", method='POST', timeout=ClientTimeout(3)) self.assertTrue(self.user.descriptions["behaviors"]["my_behavior"].success_action.run.called) self.assertTrue(self.user.variables.set.called) @@ -71,8 +68,7 @@ async def test_render_params(self, request_mock: Mock): "url": "my.url.com", "value": "my_value" } - async for command in HTTPRequestAction(items).run(self.user, None, params): - pass + await HTTPRequestAction(items).run(self.user, None, params) request_mock.assert_called_with( url="https://my.url.com", method='POST', timeout=ClientTimeout(3), json={"param": "my_value"} ) @@ -93,8 +89,7 @@ async def test_headers_fix(self, request_mock): "store": "user_variable", "behavior": "my_behavior", } - async for command in HTTPRequestAction(items).run(self.user, None, {}): - pass + await HTTPRequestAction(items).run(self.user, None, {}) request_mock.assert_called_with(headers={ "header_1": "32", "header_2": "32.03", @@ -109,7 +104,6 @@ async def test_behavior_is_none(self, request_mock): }, "store": "user_variable", } - async for command in HTTPRequestAction(items).run(self.user, None, {}): - pass + await HTTPRequestAction(items).run(self.user, None, {}) request_mock.assert_called_with(method=HTTPRequestAction.DEFAULT_METHOD, timeout=ClientTimeout(HTTPRequestAction.DEFAULT_TIMEOUT)) diff --git a/tests/smart_kit_tests/action/test_run_scenario_by_project_name.py b/tests/smart_kit_tests/action/test_run_scenario_by_project_name.py index 73d7ad0c..9f769177 100644 --- a/tests/smart_kit_tests/action/test_run_scenario_by_project_name.py +++ b/tests/smart_kit_tests/action/test_run_scenario_by_project_name.py @@ -8,7 +8,7 @@ class TestScenarioDesc(dict): async def run(self, argv1, argv2, params): - yield 'result to run scenario' + return ['result to run scenario'] class RunScenarioByProjectNameActionTest1(unittest.IsolatedAsyncioTestCase): @@ -35,18 +35,12 @@ def setUp(self): async def test_run_scenario_by_project_name_run(self): obj1 = RunScenarioByProjectNameAction(self.items) # без оглядки на аннотации из PEP 484 - result = [] - async for command in obj1.run(self.test_user1, self.test_text_preprocessing_result, {'any_attr': {'any_data'}}): - result.append(command) self.assertEqual( - result, + await obj1.run(self.test_user1, self.test_text_preprocessing_result, {'any_attr': {'any_data'}}), ['result to run scenario'] ) obj2 = RunScenarioByProjectNameAction(self.items) - result = [] - async for command in obj2.run(self.test_user2, self.test_text_preprocessing_result): - result.append(command) - self.assertEqual([], result) + self.assertEqual([], await obj2.run(self.test_user2, self.test_text_preprocessing_result)) def test_run_scenario_by_project_name_log_vars(self): obj = RunScenarioByProjectNameAction(self.items) diff --git a/tests/smart_kit_tests/handlers/test_handle_close_app.py b/tests/smart_kit_tests/handlers/test_handle_close_app.py index 16b02175..76a3c318 100644 --- a/tests/smart_kit_tests/handlers/test_handle_close_app.py +++ b/tests/smart_kit_tests/handlers/test_handle_close_app.py @@ -40,7 +40,4 @@ def test_handler_close_app_init(self): async def test_handler_close_app_run(self): self.assertIsNotNone(handle_close_app.log_const.KEY_NAME) obj = handle_close_app.HandlerCloseApp(app_name=self.app_name) - result = [] - async for command in obj.run(self.test_payload, self.test_user): - result.append(command) - self.assertEqual([], result) + self.assertEqual([], await obj.run(self.test_payload, self.test_user)) diff --git a/tests/smart_kit_tests/handlers/test_handle_respond.py b/tests/smart_kit_tests/handlers/test_handle_respond.py index d08e36d2..98487885 100644 --- a/tests/smart_kit_tests/handlers/test_handle_respond.py +++ b/tests/smart_kit_tests/handlers/test_handle_respond.py @@ -7,7 +7,7 @@ async def mock_test_action_run(x, y, z): - yield 10 + return [10] class HandlerTest4(unittest.IsolatedAsyncioTestCase): @@ -74,9 +74,5 @@ async def test_handler_respond_run(self): obj1 = handle_respond.HandlerRespond(app_name=self.app_name) obj2 = handle_respond.HandlerRespond(self.app_name, "any action name") with self.assertRaises(KeyError): - async for command in obj1.run(self.test_payload, self.test_user1): - pass - result = [] - async for command in obj2.run(self.test_payload, self.test_user2): - result.append(command) - self.assertTrue(result == [10]) + await obj1.run(self.test_payload, self.test_user1) + self.assertTrue(await obj2.run(self.test_payload, self.test_user2) == [10]) diff --git a/tests/smart_kit_tests/handlers/test_handle_take_profile_data.py b/tests/smart_kit_tests/handlers/test_handle_take_profile_data.py index 1f5f0038..4609bec5 100644 --- a/tests/smart_kit_tests/handlers/test_handle_take_profile_data.py +++ b/tests/smart_kit_tests/handlers/test_handle_take_profile_data.py @@ -6,15 +6,15 @@ async def success(x): - yield "success" + return "success" async def fail(x): - yield "fail" + return "fail" async def timeout(x): - yield "timeout" + return "timeout" class MockVariables(Mock): @@ -40,8 +40,7 @@ def setUp(self): def behavior_outcome(result): async def outcome(x): - for command in result: - yield command + return result return outcome self.test_user = MagicMock('user', message=MagicMock(message_name="some_name"), variables=MockVariables(), @@ -57,18 +56,12 @@ async def test_handle_take_profile_data_init(self): async def test_handle_take_profile_data_run_fail(self): obj = HandlerTakeProfileData(self.app_name) payload = {"status_code": {"code": 102}} - result = [] - async for command in obj.run(payload, self.test_user): - result.append(command) - self.assertEqual(result, ["fail"]) + self.assertEqual(await obj.run(payload, self.test_user), ["fail"]) async def test_handle_take_profile_data_run_success(self): obj = HandlerTakeProfileData(self.app_name) payload = {"profile_data": {"geo": {"reverseGeocoding": {"country": "Российская Федерация"}, "location": {"lat": 10.125, "lon": 10.0124}}}, "status_code": {"code": 1}} - result = [] - async for command in obj.run(payload, self.test_user): - result.append(command) - self.assertEqual(result, ["success"]) + self.assertEqual(await obj.run(payload, self.test_user), ["success"]) self.assertEqual(self.test_user.variables.get("smart_geo"), payload["profile_data"]["geo"]) diff --git a/tests/smart_kit_tests/handlers/test_handler_text.py b/tests/smart_kit_tests/handlers/test_handler_text.py index 8bf16f47..1ebe88ab 100644 --- a/tests/smart_kit_tests/handlers/test_handler_text.py +++ b/tests/smart_kit_tests/handlers/test_handler_text.py @@ -7,12 +7,11 @@ async def mock_dialogue_manager1_run(x, y): - yield "TestAnswer" + return ["TestAnswer"], True async def mock_dialogue_manager2_run(x, y): - return - yield + return [], False class HandlerTest5(unittest.IsolatedAsyncioTestCase): @@ -51,24 +50,12 @@ def test_handler_text_init(self): async def test_handler_text_handle_base(self): obj1 = handler_text.HandlerText(self.app_name, self.test_dialog_manager1) obj2 = handler_text.HandlerText(self.app_name, self.test_dialog_manager2) - result = [] - async for command in obj1._handle_base(self.test_text_preprocessing_result, self.test_user): - result.append(command) - self.assertEqual(result, ["TestAnswer"]) - result = [] - async for command in obj2._handle_base(self.test_text_preprocessing_result, self.test_user): - result.append(command) - self.assertEqual(result, []) + self.assertEqual(await obj1._handle_base(self.test_text_preprocessing_result, self.test_user), ["TestAnswer"]) + self.assertEqual(await obj2._handle_base(self.test_text_preprocessing_result, self.test_user), []) async def test_handler_text_run(self): self.assertIsNotNone(handler_text.log_const.NORMALIZED_TEXT_VALUE) obj1 = handler_text.HandlerText(self.app_name, self.test_dialog_manager1) obj2 = handler_text.HandlerText(self.app_name, self.test_dialog_manager2) - result = [] - async for command in obj1.run(self.test_payload, self.test_user): - result.append(command) - self.assertEqual(result, ["TestAnswer"]) - result = [] - async for command in obj2.run(self.test_payload, self.test_user): - result.append(command) - self.assertEqual(result, []) + self.assertEqual(await obj1.run(self.test_payload, self.test_user), ["TestAnswer"]) + self.assertEqual(await obj2.run(self.test_payload, self.test_user), []) diff --git a/tests/smart_kit_tests/handlers/test_handler_timeout.py b/tests/smart_kit_tests/handlers/test_handler_timeout.py index 8625406e..5f06eeb9 100644 --- a/tests/smart_kit_tests/handlers/test_handler_timeout.py +++ b/tests/smart_kit_tests/handlers/test_handler_timeout.py @@ -7,7 +7,7 @@ async def mock_behaviors_timeout(x): - yield 120 + return [120] class HandlerTest2(unittest.IsolatedAsyncioTestCase): @@ -35,7 +35,4 @@ async def test_handler_timeout(self): obj = handler_timeout.HandlerTimeout(self.app_name) self.assertIsNotNone(obj.KAFKA_KEY) self.assertIsNotNone(handler_timeout.log_const.KEY_NAME) - result = [] - async for command in obj.run(self.test_payload, self.test_user): - result.append(command) - self.assertTrue(result == [120]) + self.assertTrue(await obj.run(self.test_payload, self.test_user) == [120]) diff --git a/tests/smart_kit_tests/models/test_dialogue_manager.py b/tests/smart_kit_tests/models/test_dialogue_manager.py index 368916c1..0e0d6c86 100644 --- a/tests/smart_kit_tests/models/test_dialogue_manager.py +++ b/tests/smart_kit_tests/models/test_dialogue_manager.py @@ -23,11 +23,11 @@ async def mock_scenario2_text_fits(): async def mock_scenario1_run(x, y): - yield x.name + y.name + return x.name + y.name async def mock_scenario2_run(x, y): - yield y.name + x.name + return y.name + x.name class ModelsTest1(unittest.IsolatedAsyncioTestCase): @@ -108,47 +108,29 @@ async def test_dialogue_manager_run(self): 'external_actions': {}}, self.app_name) # путь по умолчанию без выполнения условий - result = [] - async for command in obj1.run(self.test_text_preprocessing_result, self.test_user1): - result.append(command) - self.assertEqual( - result, ["TestNameResult"] + self.assertTrue( + await obj1.run(self.test_text_preprocessing_result, self.test_user1) == ("TestNameResult", True) ) - result = [] - async for command in obj2.run(self.test_text_preprocessing_result, self.test_user1): - result.append(command) - self.assertEqual( - result, ["TestNameResult"] + self.assertTrue( + await obj2.run(self.test_text_preprocessing_result, self.test_user1) == ("TestNameResult", True) ) # случай когда срабатоли оба условия - result = [] - async for command in obj1.run(self.test_text_preprocessing_result, self.test_user2): - result.append(command) - self.assertEqual( - result, ["TestNameResult"] + self.assertTrue( + await obj1.run(self.test_text_preprocessing_result, self.test_user2) == ("TestNameResult", True) ) # случай, когда 2-е условие не выполнено - result = [] - async for command in obj2.run(self.test_text_preprocessing_result, self.test_user3): - result.append(command) - self.assertEqual( - result, ['TestNameResult'] + self.assertTrue( + await obj2.run(self.test_text_preprocessing_result, self.test_user3) == ('TestNameResult', True) ) async def test_dialogue_manager_run_scenario(self): obj = dialogue_manager.DialogueManager({'scenarios': self.test_scenarios, 'external_actions': {'nothing_found_action': self.TestAction}}, self.app_name) - result = [] - async for command in obj.run_scenario(1, self.test_text_preprocessing_result, self.test_user1): - result.append(command) - self.assertEqual( - result, ["ResultTestName"] + self.assertTrue( + await obj.run_scenario(1, self.test_text_preprocessing_result, self.test_user1) == "ResultTestName" ) - result = [] - async for command in obj.run_scenario(2, self.test_text_preprocessing_result, self.test_user1): - result.append(command) - self.assertEqual( - result, ["TestNameResult"] + self.assertTrue( + await obj.run_scenario(2, self.test_text_preprocessing_result, self.test_user1) == "TestNameResult" ) diff --git a/tests/smart_kit_tests/system_answers/test_nothing_found_action.py b/tests/smart_kit_tests/system_answers/test_nothing_found_action.py index 587adb1a..6f6248e2 100644 --- a/tests/smart_kit_tests/system_answers/test_nothing_found_action.py +++ b/tests/smart_kit_tests/system_answers/test_nothing_found_action.py @@ -29,17 +29,9 @@ def test_system_answers_nothing_found_action_init(self): async def test_system_answer_nothing_found_action_run(self): obj1 = nothing_found_action.NothingFoundAction() obj2 = nothing_found_action.NothingFoundAction(self.test_items1, self.test_id) - result = [] - async for command in obj1.run(self.test_user1, self.test_text_preprocessing_result): - result.append(command) - command = result[0] self.assertTrue(isinstance( - command, Command) + (await obj1.run(self.test_user1, self.test_text_preprocessing_result)).pop(), Command) ) - result = [] - async for command in obj2.run(self.test_user1, self.test_text_preprocessing_result): - result.append(command) - command = result[0] self.assertTrue(isinstance( - command, Command) + (await obj2.run(self.test_user1, self.test_text_preprocessing_result)).pop(), Command) )